From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365

---
 funasr/frontends/whisper_frontend.py |   65 ++++++++++++++++++++------------
 1 file changed, 40 insertions(+), 25 deletions(-)

diff --git a/funasr/frontends/whisper_frontend.py b/funasr/frontends/whisper_frontend.py
index 752fd20..1bd8aec 100644
--- a/funasr/frontends/whisper_frontend.py
+++ b/funasr/frontends/whisper_frontend.py
@@ -1,8 +1,8 @@
 from typing import Tuple
 import torch
 import torch.nn as nn
-import whisper
-from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
+
+
 from funasr.register import tables
 from torch.nn.utils.rnn import pad_sequence
 
@@ -15,50 +15,62 @@
     """
 
     def __init__(
-            self,
-            fs: int = 16000,
-            whisper_model: str = "large-v3",
-            do_pad_trim: bool = True,
+        self,
+        fs: int = 16000,
+        whisper_model: str = None,
+        do_pad_trim: bool = True,
+        n_mels: int = 80,
+        permute: bool = False,
+        **kwargs,
     ):
         super().__init__()
         assert fs == 16000
         self.fs = fs
+        import whisper
+        from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
 
         self.n_fft = N_FFT
         self.win_length = N_FFT
         self.hop_length = HOP_LENGTH
         self.pad_samples = N_SAMPLES
-        self.frame_shift = self.hop_length
+        self.frame_shift = int(self.hop_length / self.fs * 1000)
         self.lfr_n = 1
+        self.n_mels = n_mels
         if whisper_model == "large-v3" or whisper_model == "large":
             self.n_mels = 128
-        else:
-            self.n_mels = 80
 
-        self.mel_filters = whisper.audio.mel_filters
+        filters_path = kwargs.get("filters_path", None)
+        self.filters_path = filters_path
+        if filters_path is not None:
+            from funasr.models.sense_voice.whisper_lib.audio import mel_filters
+
+            self.mel_filters = mel_filters
+        else:
+            self.mel_filters = whisper.audio.mel_filters
         self.do_pad_trim = do_pad_trim
         if do_pad_trim:
             self.pad_or_trim = whisper.pad_or_trim
+        self.permute = permute
 
-        assert whisper_model in whisper.available_models()
+        # assert whisper_model in whisper.available_models()
 
     def output_size(self) -> int:
         return self.n_mels
 
     def log_mel_spectrogram(
-            self,
-            audio: torch.Tensor,
-            ilens: torch.Tensor = None,
+        self,
+        audio: torch.Tensor,
+        ilens: torch.Tensor = None,
     ) -> torch.Tensor:
         window = torch.hann_window(self.win_length).to(audio.device)
-        stft = torch.stft(
-            audio, self.n_fft, self.hop_length, window=window, return_complex=True
-        )
+        stft = torch.stft(audio, self.n_fft, self.hop_length, window=window, return_complex=True)
 
         # whisper deletes the last frame by default (Shih-Lun)
         magnitudes = stft[..., :-1].abs() ** 2
-
-        filters = self.mel_filters(audio.device, self.n_mels)
+        if self.filters_path is not None:
+            filters = self.mel_filters(audio.device, self.n_mels, self.filters_path)
+        else:
+            filters = self.mel_filters(audio.device, self.n_mels)
         mel_spec = filters @ magnitudes
 
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
@@ -77,11 +89,15 @@
         return log_spec, olens
 
     def forward(
-            self, input: torch.Tensor, input_lengths: torch.Tensor
+        self,
+        input: torch.Tensor,
+        input_lengths: torch.Tensor,
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size = input.size(0)
         feats = []
         feats_lens = []
+        input = input.to(torch.float32)
         for i in range(batch_size):
             if self.do_pad_trim:
                 feat = self.pad_or_trim(input[i], self.pad_samples)
@@ -95,8 +111,7 @@
         if batch_size == 1:
             feats_pad = feats[0][None, :, :]
         else:
-            feats_pad = pad_sequence(feats,
-                                     batch_first=True,
-                                     padding_value=0.0)
-
-        return feats_pad, feats_lens
\ No newline at end of file
+            feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
+        if self.permute:
+            feats_pad = feats_pad.permute(0, 2, 1)
+        return feats_pad, feats_lens

--
Gitblit v1.9.1