From 2ac38adbe5f4e1374a079e032ed4b504351a207c Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 23 四月 2024 18:08:57 +0800
Subject: [PATCH] Dev gzf exp (#1647)

---
 funasr/frontends/whisper_frontend.py |   15 ++++++++++++---
 1 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/funasr/frontends/whisper_frontend.py b/funasr/frontends/whisper_frontend.py
index dd61f8e..acc99af 100644
--- a/funasr/frontends/whisper_frontend.py
+++ b/funasr/frontends/whisper_frontend.py
@@ -38,7 +38,13 @@
         if whisper_model == "large-v3" or whisper_model == "large":
             self.n_mels = 128
 
-        self.mel_filters = whisper.audio.mel_filters
+        filters_path = kwargs.get("filters_path", None)
+        self.filters_path = filters_path
+        if filters_path is not None:
+            from funasr.models.sense_voice.whisper_lib.audio import mel_filters
+            self.mel_filters = mel_filters
+        else:
+            self.mel_filters = whisper.audio.mel_filters
         self.do_pad_trim = do_pad_trim
         if do_pad_trim:
             self.pad_or_trim = whisper.pad_or_trim
@@ -61,8 +67,10 @@
 
         # whisper deletes the last frame by default (Shih-Lun)
         magnitudes = stft[..., :-1].abs() ** 2
-
-        filters = self.mel_filters(audio.device, self.n_mels)
+        if self.filters_path is not None:
+            filters = self.mel_filters(audio.device, self.n_mels, self.filters_path)
+        else:
+            filters = self.mel_filters(audio.device, self.n_mels)
         mel_spec = filters @ magnitudes
 
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
@@ -86,6 +94,7 @@
         batch_size = input.size(0)
         feats = []
         feats_lens = []
+        input = input.to(torch.float32)
         for i in range(batch_size):
             if self.do_pad_trim:
                 feat = self.pad_or_trim(input[i], self.pad_samples)

--
Gitblit v1.9.1