游雁
2023-05-25 b18f7d121f2f17df8bf2d0c2bbb223bc5ddbcc0f
funasr/models/frontend/default.py
@@ -18,7 +18,6 @@
class DefaultFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.
    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """
@@ -38,6 +37,7 @@
            htk: bool = False,
            frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
            apply_stft: bool = True,
            use_channel: int = None,
    ):
        assert check_argument_types()
        super().__init__()
@@ -77,6 +77,7 @@
        )
        self.n_mels = n_mels
        self.frontend_type = "default"
        self.use_channel = use_channel
    def output_size(self) -> int:
        return self.n_mels
@@ -100,9 +101,12 @@
        if input_stft.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(input_stft.size(2))
                input_stft = input_stft[:, :, ch, :]
                if self.use_channel is not None:
                    input_stft = input_stft[:, :, self.use_channel, :]
                else:
                    # Select 1ch randomly
                    ch = np.random.randint(input_stft.size(2))
                    input_stft = input_stft[:, :, ch, :]
            else:
                # Use the first channel
                input_stft = input_stft[:, :, 0, :]
@@ -137,7 +141,6 @@
class MultiChannelFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.
    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """
@@ -255,4 +258,4 @@
        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
        return input_stft, feats_lens
        return input_stft, feats_lens