From e84f17adca2d8a81bc2d0229b9531e7eb0a7705c Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 26 Mar 2024 12:34:26 +0800
Subject: [PATCH] whisper_frontend: import whisper lazily in __init__, add permute option and accept **kwargs
---
funasr/frontends/whisper_frontend.py | 15 ++++++++++-----
1 file changed, 10 insertions(+), 5 deletions(-)
diff --git a/funasr/frontends/whisper_frontend.py b/funasr/frontends/whisper_frontend.py
index 9290a25..dd61f8e 100644
--- a/funasr/frontends/whisper_frontend.py
+++ b/funasr/frontends/whisper_frontend.py
@@ -1,8 +1,8 @@
from typing import Tuple
import torch
import torch.nn as nn
-import whisper
-from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
+
+
from funasr.register import tables
from torch.nn.utils.rnn import pad_sequence
@@ -20,11 +20,14 @@
whisper_model: str = None,
do_pad_trim: bool = True,
n_mels: int = 80,
+ permute: bool = False,
+ **kwargs,
):
super().__init__()
assert fs == 16000
self.fs = fs
-
+ import whisper
+ from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
self.n_fft = N_FFT
self.win_length = N_FFT
self.hop_length = HOP_LENGTH
@@ -39,6 +42,7 @@
self.do_pad_trim = do_pad_trim
if do_pad_trim:
self.pad_or_trim = whisper.pad_or_trim
+ self.permute = permute
# assert whisper_model in whisper.available_models()
@@ -77,7 +81,7 @@
return log_spec, olens
def forward(
- self, input: torch.Tensor, input_lengths: torch.Tensor
+ self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
@@ -98,5 +102,6 @@
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
-
+ if self.permute:
+ feats_pad = feats_pad.permute(0, 2, 1)
return feats_pad, feats_lens
\ No newline at end of file
--
Gitblit v1.9.1