From be3ade87488f70104f5be71d891c0c8500ffdedd Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: Fri, 10 Feb 2023 19:07:15 +0800
Subject: [PATCH] add sond model
---
funasr/models/frontend/wav_frontend.py | 12 +++++++++---
1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 57c5976..7a6425b 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -90,7 +90,9 @@
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
- dither: float = 1.0
+ dither: float = 1.0,
+ snip_edges: bool = True,
+ upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
@@ -105,6 +107,8 @@
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
self.dither = dither
+ self.snip_edges = snip_edges
+ self.upsacle_samples = upsacle_samples
def output_size(self) -> int:
return self.n_mels * self.lfr_m
@@ -119,7 +123,8 @@
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
- waveform = waveform * (1 << 15)
+ if self.upsacle_samples:
+ waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(waveform,
num_mel_bins=self.n_mels,
@@ -128,7 +133,8 @@
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
- sample_frequency=self.fs)
+ sample_frequency=self.fs,
+ snip_edges=self.snip_edges)
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
--
Gitblit v1.9.1