From 2cdb2d654f2109ef4e648bae6f169143e267e5db Mon Sep 17 00:00:00 2001
From: zhuzizyf <42790740+zhuzizyf@users.noreply.github.com>
Date: Sat, 11 Mar 2023 14:33:14 +0800
Subject: [PATCH] Update dataset.py
---
funasr/models/frontend/wav_frontend.py | 17 ++++++++++-------
1 file changed, 10 insertions(+), 7 deletions(-)
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 57c5976..ed8cb36 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -90,7 +90,9 @@
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
- dither: float = 1.0
+ dither: float = 1.0,
+ snip_edges: bool = True,
+ upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
@@ -105,6 +107,8 @@
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
self.dither = dither
+ self.snip_edges = snip_edges
+ self.upsacle_samples = upsacle_samples
def output_size(self) -> int:
return self.n_mels * self.lfr_m
@@ -119,7 +123,8 @@
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
- waveform = waveform * (1 << 15)
+ if self.upsacle_samples:
+ waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(waveform,
num_mel_bins=self.n_mels,
@@ -128,7 +133,8 @@
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
- sample_frequency=self.fs)
+ sample_frequency=self.fs,
+ snip_edges=self.snip_edges)
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
@@ -165,10 +171,7 @@
window_type=self.window,
sample_frequency=self.fs)
- # if self.lfr_m != 1 or self.lfr_n != 1:
- # mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
- # if self.cmvn_file is not None:
- # mat = apply_cmvn(mat, self.cmvn_file)
+
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
--
Gitblit v1.9.1