From 9b4e9cc8a0311e5243d69b73ed073e7ea441982e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 27 Mar 2024 16:05:29 +0800
Subject: [PATCH] train update

---
 funasr/frontends/whisper_frontend.py |   23 ++++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/funasr/frontends/whisper_frontend.py b/funasr/frontends/whisper_frontend.py
index 752fd20..dd61f8e 100644
--- a/funasr/frontends/whisper_frontend.py
+++ b/funasr/frontends/whisper_frontend.py
@@ -1,8 +1,8 @@
 from typing import Tuple
 import torch
 import torch.nn as nn
-import whisper
-from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
+
+
 from funasr.register import tables
 from torch.nn.utils.rnn import pad_sequence
 
@@ -17,30 +17,34 @@
     def __init__(
             self,
             fs: int = 16000,
-            whisper_model: str = "large-v3",
+            whisper_model: str = None,
             do_pad_trim: bool = True,
+            n_mels: int = 80,
+            permute: bool = False,
+            **kwargs,
     ):
         super().__init__()
         assert fs == 16000
         self.fs = fs
-
+        import whisper
+        from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
         self.n_fft = N_FFT
         self.win_length = N_FFT
         self.hop_length = HOP_LENGTH
         self.pad_samples = N_SAMPLES
         self.frame_shift = self.hop_length
         self.lfr_n = 1
+        self.n_mels = n_mels
         if whisper_model == "large-v3" or whisper_model == "large":
             self.n_mels = 128
-        else:
-            self.n_mels = 80
 
         self.mel_filters = whisper.audio.mel_filters
         self.do_pad_trim = do_pad_trim
         if do_pad_trim:
             self.pad_or_trim = whisper.pad_or_trim
+        self.permute = permute
 
-        assert whisper_model in whisper.available_models()
+        # assert whisper_model in whisper.available_models()
 
     def output_size(self) -> int:
         return self.n_mels
@@ -77,7 +81,7 @@
         return log_spec, olens
 
     def forward(
-            self, input: torch.Tensor, input_lengths: torch.Tensor
+            self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size = input.size(0)
         feats = []
@@ -98,5 +102,6 @@
             feats_pad = pad_sequence(feats,
                                      batch_first=True,
                                      padding_value=0.0)
-
+        if self.permute:
+            feats_pad = feats_pad.permute(0, 2, 1)
         return feats_pad, feats_lens
\ No newline at end of file

--
Gitblit v1.9.1