From 9ba0dbd98bf69c830dfcfde8f109a400cb65e4e5 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期五, 29 三月 2024 17:24:59 +0800
Subject: [PATCH] fix func Forward

---
 funasr/frontends/whisper_frontend.py |    8 ++++++--
 1 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/funasr/frontends/whisper_frontend.py b/funasr/frontends/whisper_frontend.py
index 9290a25..0598c61 100644
--- a/funasr/frontends/whisper_frontend.py
+++ b/funasr/frontends/whisper_frontend.py
@@ -20,6 +20,8 @@
             whisper_model: str = None,
             do_pad_trim: bool = True,
             n_mels: int = 80,
+            permute: bool = False,
+            **kwargs,
     ):
         super().__init__()
         assert fs == 16000
@@ -39,6 +41,7 @@
         self.do_pad_trim = do_pad_trim
         if do_pad_trim:
             self.pad_or_trim = whisper.pad_or_trim
+        self.permute = permute
 
         # assert whisper_model in whisper.available_models()
 
@@ -77,7 +80,7 @@
         return log_spec, olens
 
     def forward(
-            self, input: torch.Tensor, input_lengths: torch.Tensor
+            self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size = input.size(0)
         feats = []
@@ -98,5 +101,6 @@
             feats_pad = pad_sequence(feats,
                                      batch_first=True,
                                      padding_value=0.0)
-
+        if self.permute:
+            feats_pad = feats_pad.permute(0, 2, 1)
         return feats_pad, feats_lens
\ No newline at end of file

--
Gitblit v1.9.1