zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/frontends/whisper_frontend.py
@@ -28,6 +28,7 @@
        self.fs = fs
        import whisper
        from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
        self.n_fft = N_FFT
        self.win_length = N_FFT
        self.hop_length = HOP_LENGTH
@@ -42,6 +43,7 @@
        self.filters_path = filters_path
        if filters_path is not None:
            from funasr.models.sense_voice.whisper_lib.audio import mel_filters
            self.mel_filters = mel_filters
        else:
            self.mel_filters = whisper.audio.mel_filters
@@ -61,9 +63,7 @@
            ilens: torch.Tensor = None,
    ) -> torch.Tensor:
        window = torch.hann_window(self.win_length).to(audio.device)
        stft = torch.stft(
            audio, self.n_fft, self.hop_length, window=window, return_complex=True
        )
        stft = torch.stft(audio, self.n_fft, self.hop_length, window=window, return_complex=True)
        # whisper deletes the last frame by default (Shih-Lun)
        magnitudes = stft[..., :-1].abs() ** 2
@@ -89,7 +89,10 @@
        return log_spec, olens
    def forward(
            self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs,
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
@@ -108,9 +111,7 @@
        if batch_size == 1:
            feats_pad = feats[0][None, :, :]
        else:
            feats_pad = pad_sequence(feats,
                                     batch_first=True,
                                     padding_value=0.0)
            feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
        if self.permute:
            feats_pad = feats_pad.permute(0, 2, 1)
        return feats_pad, feats_lens