1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
| from typing import List
| from typing import Optional
| from typing import Tuple
| from typing import Union
|
| import numpy
| import torch
| import torch.nn as nn
| from torch_complex.tensor import ComplexTensor
|
| from funasr.frontends.utils.dnn_beamformer import DNN_Beamformer
| from funasr.frontends.utils.dnn_wpe import DNN_WPE
|
|
class Frontend(nn.Module):
    """Multi-channel speech-enhancement frontend.

    Optionally applies WPE dereverberation and/or DNN mask-based
    beamforming to complex STFT features before recognition. Either
    stage can be disabled; a 3-dim (single-channel) input is passed
    through untouched.
    """

    def __init__(
        self,
        idim: int,
        # WPE options
        use_wpe: bool = False,
        wtype: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = False,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        bnmask: int = 2,
        badim: int = 320,
        ref_channel: int = -1,
        bdropout_rate=0.0,
    ):
        super().__init__()

        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe
        self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
        # More than two masks implies the frontend must process every
        # utterance, e.g. multi-speaker speech separation.
        self.use_frontend_for_all = bnmask > 2

        self.wpe = None
        if self.use_wpe:
            # With a DNN power estimator one iteration is enough (no
            # significant gains were observed from more); conventional
            # WPE without the estimator uses two.
            self.wpe = DNN_WPE(
                wtype=wtype,
                widim=idim,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=1 if use_dnn_mask_for_wpe else 2,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )

        self.beamformer = None
        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=btype,
                bidim=idim,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                bnmask=bnmask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
            )

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
    ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
        """Enhance *x* and return ``(enhanced, ilens, mask)``.

        Args:
            x: Complex STFT features, ``(B, T, F)`` or ``(B, T, C, F)``.
            ilens: Per-utterance lengths (tensor, ndarray, or list).

        Returns:
            The (possibly) enhanced features, the lengths as a tensor on
            ``x``'s device, and the last estimated mask (``None`` when no
            enhancement ran).
        """
        assert len(x) == len(ilens), (len(x), len(ilens))
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)

        enhanced = x
        mask = None
        # Single-channel input: nothing to dereverberate or beamform.
        if enhanced.dim() != 4:
            return enhanced, ilens, mask

        if self.training:
            # Randomly pick one enhancement path per batch so the model
            # also trains on raw data — unless the frontend must run on
            # everything (multi-speaker case).
            candidates = [] if self.use_frontend_for_all else [(False, False)]
            if self.use_wpe:
                candidates.append((True, False))
            if self.use_beamformer:
                candidates.append((False, True))
            apply_wpe, apply_beamformer = candidates[numpy.random.randint(len(candidates))]
        else:
            apply_wpe = self.use_wpe
            apply_beamformer = self.use_beamformer

        # 1. WPE dereverberation: (B, T, C, F) -> (B, T, C, F)
        if apply_wpe:
            enhanced, ilens, mask = self.wpe(enhanced, ilens)

        # 2. Beamforming: (B, T, C, F) -> (B, T, F)
        if apply_beamformer:
            enhanced, ilens, mask = self.beamformer(enhanced, ilens)

        return enhanced, ilens, mask
|
|
def frontend_for(args, idim):
    """Build a :class:`Frontend` from a parsed-argument namespace.

    Args:
        args: Namespace exposing the WPE and beamformer configuration
            attributes (``use_wpe``, ``wtype``, ..., ``bdropout_rate``).
        idim: Input feature dimension.

    Returns:
        Frontend: the configured enhancement frontend.
    """
    # Group the namespace attributes by the stage they configure.
    wpe_options = dict(
        use_wpe=args.use_wpe,
        wtype=args.wtype,
        wlayers=args.wlayers,
        wunits=args.wunits,
        wprojs=args.wprojs,
        wdropout_rate=args.wdropout_rate,
        taps=args.wpe_taps,
        delay=args.wpe_delay,
        use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe,
    )
    beamformer_options = dict(
        use_beamformer=args.use_beamformer,
        btype=args.btype,
        blayers=args.blayers,
        bunits=args.bunits,
        bprojs=args.bprojs,
        bnmask=args.bnmask,
        badim=args.badim,
        ref_channel=args.ref_channel,
        bdropout_rate=args.bdropout_rate,
    )
    return Frontend(idim=idim, **wpe_options, **beamformer_options)
|
|