From e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc Mon Sep 17 00:00:00 2001
From: VirtuosoQ <2416050435@qq.com>
Date: Fri, 26 Apr 2024 14:59:30 +0800
Subject: [PATCH] FunASR java http client
---
funasr/frontends/whisper_frontend.py | 38 ++++++++++++++++++++++++++------------
 1 file changed, 26 insertions(+), 12 deletions(-)
diff --git a/funasr/frontends/whisper_frontend.py b/funasr/frontends/whisper_frontend.py
index 752fd20..acc99af 100644
--- a/funasr/frontends/whisper_frontend.py
+++ b/funasr/frontends/whisper_frontend.py
@@ -1,8 +1,8 @@
from typing import Tuple
import torch
import torch.nn as nn
-import whisper
-from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
+
+
from funasr.register import tables
from torch.nn.utils.rnn import pad_sequence
@@ -17,30 +17,40 @@
def __init__(
self,
fs: int = 16000,
- whisper_model: str = "large-v3",
+ whisper_model: str = None,
do_pad_trim: bool = True,
+ n_mels: int = 80,
+ permute: bool = False,
+ **kwargs,
):
super().__init__()
assert fs == 16000
self.fs = fs
-
+ import whisper
+ from whisper.audio import HOP_LENGTH, N_FFT, N_SAMPLES
self.n_fft = N_FFT
self.win_length = N_FFT
self.hop_length = HOP_LENGTH
self.pad_samples = N_SAMPLES
self.frame_shift = self.hop_length
self.lfr_n = 1
+ self.n_mels = n_mels
if whisper_model == "large-v3" or whisper_model == "large":
self.n_mels = 128
- else:
- self.n_mels = 80
- self.mel_filters = whisper.audio.mel_filters
+ filters_path = kwargs.get("filters_path", None)
+ self.filters_path = filters_path
+ if filters_path is not None:
+ from funasr.models.sense_voice.whisper_lib.audio import mel_filters
+ self.mel_filters = mel_filters
+ else:
+ self.mel_filters = whisper.audio.mel_filters
self.do_pad_trim = do_pad_trim
if do_pad_trim:
self.pad_or_trim = whisper.pad_or_trim
+ self.permute = permute
- assert whisper_model in whisper.available_models()
+ # assert whisper_model in whisper.available_models()
def output_size(self) -> int:
return self.n_mels
@@ -57,8 +67,10 @@
# whisper deletes the last frame by default (Shih-Lun)
magnitudes = stft[..., :-1].abs() ** 2
-
- filters = self.mel_filters(audio.device, self.n_mels)
+ if self.filters_path is not None:
+ filters = self.mel_filters(audio.device, self.n_mels, self.filters_path)
+ else:
+ filters = self.mel_filters(audio.device, self.n_mels)
mel_spec = filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
@@ -77,11 +89,12 @@
return log_spec, olens
def forward(
- self, input: torch.Tensor, input_lengths: torch.Tensor
+ self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
+ input = input.to(torch.float32)
for i in range(batch_size):
if self.do_pad_trim:
feat = self.pad_or_trim(input[i], self.pad_samples)
@@ -98,5 +111,6 @@
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
-
+ if self.permute:
+ feats_pad = feats_pad.permute(0, 2, 1)
return feats_pad, feats_lens
\ No newline at end of file
--
Gitblit v1.9.1