from typing import Tuple
import torch
import torch.nn as nn

from funasr.register import tables
from torch.nn.utils.rnn import pad_sequence


# NOTE: the class wrapper and constructor signature are inferred from the code
# below; the registry key, class name, and parameter defaults are assumptions.
@tables.register("frontend_classes", "WhisperFrontend")
class WhisperFrontend(nn.Module):
    def __init__(
        self,
        fs: int = 16000,
        whisper_model: str = "large-v3",
        do_pad_trim: bool = False,
        **kwargs,
    ):
        super().__init__()
        assert fs == 16000
        self.fs = fs

        import whisper
        from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
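
        # whisper's STFT setup: 400-sample FFT/window and a 160-sample hop (10 ms at 16 kHz)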
        self.n_fft = N_FFT
        self.win_length = N_FFT
        self.hop_length = HOP_LENGTH

        # default mel-bin count for most Whisper checkpoints (assumed default;
        # the large / large-v3 checkpoints use 128 bins, handled below)
        self.n_mels = 80
        if whisper_model == "large-v3" or whisper_model == "large":
            self.n_mels = 128
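
        # mel filterbank: whisper's own filters by default, or the custom filters shipped
        # with SenseVoice's whisper_lib when a `filters_path` kwarg is given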
        filters_path = kwargs.get("filters_path", None)
        self.filters_path = filters_path
        if filters_path is not None:
            from funasr.models.sense_voice.whisper_lib.audio import mel_filters
            self.mel_filters = mel_filters
        else:
            self.mel_filters = whisper.audio.mel_filters

        self.do_pad_trim = do_pad_trim
        if do_pad_trim:
            self.pad_or_trim = whisper.pad_or_trim
        # whisper.pad_or_trim targets N_SAMPLES (30 s of 16 kHz audio)
        self.pad_samples = N_SAMPLES

    def log_mel_spectrogram(self, audio):
        # NOTE: the method signature and the STFT step are reconstructed here,
        # assuming whisper's log-mel recipe (hann window, N_FFT, HOP_LENGTH).
        window = torch.hann_window(self.win_length).to(audio.device)
        stft = torch.stft(
            audio, self.n_fft, self.hop_length, window=window, return_complex=True
        )

        # whisper deletes the last frame by default (Shih-Lun)
        magnitudes = stft[..., :-1].abs() ** 2

        if self.filters_path is not None:
            filters = self.mel_filters(audio.device, self.n_mels, self.filters_path)
        else:
            filters = self.mel_filters(audio.device, self.n_mels)
        mel_spec = filters @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
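
        # floor the log-mel at 8 (log10 units) below the per-utterance max, then
        # shift/scale into roughly [-1, 1], as in whisper's log_mel_spectrogram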
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec

    def forward(self, input: torch.Tensor, input_lengths: torch.Tensor):
        # NOTE: the return value above and this forward signature (including
        # `input_lengths`) are assumptions inferred from the surrounding code.
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        input = input.to(torch.float32)
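        # loop over the batch: optionally pad/trim each utterance to whisper's
        # fixed-length input before extracting features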
        for i in range(batch_size):
            if self.do_pad_trim:
                feat = self.pad_or_trim(input[i], self.pad_samples)