From 2cdb2d654f2109ef4e648bae6f169143e267e5db Mon Sep 17 00:00:00 2001
From: zhuzizyf <42790740+zhuzizyf@users.noreply.github.com>
Date: Sat, 11 Mar 2023 14:33:14 +0800
Subject: [PATCH] Update wav_frontend.py: add snip_edges and upsacle_samples options

---
 funasr/models/frontend/wav_frontend.py |   17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 57c5976..ed8cb36 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -90,7 +90,9 @@
         filter_length_max: int = -1,
         lfr_m: int = 1,
         lfr_n: int = 1,
-        dither: float = 1.0
+        dither: float = 1.0,
+        snip_edges: bool = True,
+        upsacle_samples: bool = True,
     ):
         assert check_argument_types()
         super().__init__()
@@ -105,6 +107,8 @@
         self.lfr_n = lfr_n
         self.cmvn_file = cmvn_file
         self.dither = dither
+        self.snip_edges = snip_edges
+        self.upsacle_samples = upsacle_samples
 
     def output_size(self) -> int:
         return self.n_mels * self.lfr_m
@@ -119,7 +123,8 @@
         for i in range(batch_size):
             waveform_length = input_lengths[i]
             waveform = input[i][:waveform_length]
-            waveform = waveform * (1 << 15)
+            if self.upsacle_samples:
+                waveform = waveform * (1 << 15)
             waveform = waveform.unsqueeze(0)
             mat = kaldi.fbank(waveform,
                               num_mel_bins=self.n_mels,
@@ -128,7 +133,8 @@
                               dither=self.dither,
                               energy_floor=0.0,
                               window_type=self.window,
-                              sample_frequency=self.fs)
+                              sample_frequency=self.fs,
+                              snip_edges=self.snip_edges)
      
             if self.lfr_m != 1 or self.lfr_n != 1:
                 mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
@@ -165,10 +171,7 @@
                               window_type=self.window,
                               sample_frequency=self.fs)
 
-            # if self.lfr_m != 1 or self.lfr_n != 1:
-            #     mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
-            # if self.cmvn_file is not None:
-            #     mat = apply_cmvn(mat, self.cmvn_file)
+
             feat_length = mat.size(0)
             feats.append(mat)
             feats_lens.append(feat_length)

--
Gitblit v1.9.1