From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交
---
funasr/frontends/whisper_frontend.py | 39 ++++++++++++++++++++-------------------
1 files changed, 20 insertions(+), 19 deletions(-)
diff --git a/funasr/frontends/whisper_frontend.py b/funasr/frontends/whisper_frontend.py
index acc99af..1bd8aec 100644
--- a/funasr/frontends/whisper_frontend.py
+++ b/funasr/frontends/whisper_frontend.py
@@ -15,24 +15,25 @@
"""
def __init__(
- self,
- fs: int = 16000,
- whisper_model: str = None,
- do_pad_trim: bool = True,
- n_mels: int = 80,
- permute: bool = False,
- **kwargs,
+ self,
+ fs: int = 16000,
+ whisper_model: str = None,
+ do_pad_trim: bool = True,
+ n_mels: int = 80,
+ permute: bool = False,
+ **kwargs,
):
super().__init__()
assert fs == 16000
self.fs = fs
import whisper
from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
+
self.n_fft = N_FFT
self.win_length = N_FFT
self.hop_length = HOP_LENGTH
self.pad_samples = N_SAMPLES
- self.frame_shift = self.hop_length
+ self.frame_shift = int(self.hop_length / self.fs * 1000)
self.lfr_n = 1
self.n_mels = n_mels
if whisper_model == "large-v3" or whisper_model == "large":
@@ -42,6 +43,7 @@
self.filters_path = filters_path
if filters_path is not None:
from funasr.models.sense_voice.whisper_lib.audio import mel_filters
+
self.mel_filters = mel_filters
else:
self.mel_filters = whisper.audio.mel_filters
@@ -56,14 +58,12 @@
return self.n_mels
def log_mel_spectrogram(
- self,
- audio: torch.Tensor,
- ilens: torch.Tensor = None,
+ self,
+ audio: torch.Tensor,
+ ilens: torch.Tensor = None,
) -> torch.Tensor:
window = torch.hann_window(self.win_length).to(audio.device)
- stft = torch.stft(
- audio, self.n_fft, self.hop_length, window=window, return_complex=True
- )
+ stft = torch.stft(audio, self.n_fft, self.hop_length, window=window, return_complex=True)
# whisper deletes the last frame by default (Shih-Lun)
magnitudes = stft[..., :-1].abs() ** 2
@@ -89,7 +89,10 @@
return log_spec, olens
def forward(
- self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs,
+ self,
+ input: torch.Tensor,
+ input_lengths: torch.Tensor,
+ **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
@@ -108,9 +111,7 @@
if batch_size == 1:
feats_pad = feats[0][None, :, :]
else:
- feats_pad = pad_sequence(feats,
- batch_first=True,
- padding_value=0.0)
+ feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0)
if self.permute:
feats_pad = feats_pad.permute(0, 2, 1)
- return feats_pad, feats_lens
\ No newline at end of file
+ return feats_pad, feats_lens
--
Gitblit v1.9.1