游雁
2024-03-27 9b4e9cc8a0311e5243d69b73ed073e7ea441982e
funasr/frontends/whisper_frontend.py
@@ -1,8 +1,8 @@
from typing import Tuple
import torch
import torch.nn as nn
import whisper
from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
from funasr.register import tables
from torch.nn.utils.rnn import pad_sequence
@@ -17,30 +17,34 @@
    def __init__(
            self,
            fs: int = 16000,
            whisper_model: str = "large-v3",
            whisper_model: str = None,
            do_pad_trim: bool = True,
            n_mels: int = 80,
            permute: bool = False,
            **kwargs,
    ):
        super().__init__()
        assert fs == 16000
        self.fs = fs
        import whisper
        from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
        self.n_fft = N_FFT
        self.win_length = N_FFT
        self.hop_length = HOP_LENGTH
        self.pad_samples = N_SAMPLES
        self.frame_shift = self.hop_length
        self.lfr_n = 1
        self.n_mels = n_mels
        if whisper_model == "large-v3" or whisper_model == "large":
            self.n_mels = 128
        else:
            self.n_mels = 80
        self.mel_filters = whisper.audio.mel_filters
        self.do_pad_trim = do_pad_trim
        if do_pad_trim:
            self.pad_or_trim = whisper.pad_or_trim
        self.permute = permute
        assert whisper_model in whisper.available_models()
        # assert whisper_model in whisper.available_models()
    def output_size(self) -> int:
        return self.n_mels
@@ -77,7 +81,7 @@
        return log_spec, olens
    def forward(
            self, input: torch.Tensor, input_lengths: torch.Tensor
            self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
@@ -98,5 +102,6 @@
            feats_pad = pad_sequence(feats,
                                     batch_first=True,
                                     padding_value=0.0)
        if self.permute:
            feats_pad = feats_pad.permute(0, 2, 1)
        return feats_pad, feats_lens