| | |
| | | |
| | | class DefaultFrontend(AbsFrontend): |
| | | """Conventional frontend structure for ASR. |
| | | |
| | | Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN |
| | | """ |
| | | |
| | |
| | | if input_stft.dim() == 4: |
| | | # h: (B, T, C, F) -> h: (B, T, F) |
| | | if self.training: |
| | | if self.use_channel == None: |
| | | input_stft = input_stft[:, :, 0, :] |
| | | if self.use_channel is not None: |
| | | input_stft = input_stft[:, :, self.use_channel, :] |
| | | else: |
| | | # Select 1ch randomly |
| | | ch = np.random.randint(input_stft.size(2)) |
| | |
| | | |
| | | class MultiChannelFrontend(AbsFrontend): |
| | | """Conventional frontend structure for ASR. |
| | | |
| | | Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN |
| | | """ |
| | | |
| | |
| | | # Change torch.Tensor to ComplexTensor |
| | | # input_stft: (..., F, 2) -> (..., F) |
| | | input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1]) |
| | | return input_stft, feats_lens |
| | | return input_stft, feats_lens |