From be3ade87488f70104f5be71d891c0c8500ffdedd Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: Fri, 10 Feb 2023 19:07:15 +0800
Subject: [PATCH] add sond model

---
 funasr/models/frontend/wav_frontend.py |   12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 57c5976..7a6425b 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -90,7 +90,9 @@
         filter_length_max: int = -1,
         lfr_m: int = 1,
         lfr_n: int = 1,
-        dither: float = 1.0
+        dither: float = 1.0,
+        snip_edges: bool = True,
+        upsacle_samples: bool = True,
     ):
         assert check_argument_types()
         super().__init__()
@@ -105,6 +107,8 @@
         self.lfr_n = lfr_n
         self.cmvn_file = cmvn_file
         self.dither = dither
+        self.snip_edges = snip_edges
+        self.upsacle_samples = upsacle_samples
 
     def output_size(self) -> int:
         return self.n_mels * self.lfr_m
@@ -119,7 +123,8 @@
         for i in range(batch_size):
             waveform_length = input_lengths[i]
             waveform = input[i][:waveform_length]
-            waveform = waveform * (1 << 15)
+            if self.upsacle_samples:
+                waveform = waveform * (1 << 15)
             waveform = waveform.unsqueeze(0)
             mat = kaldi.fbank(waveform,
                               num_mel_bins=self.n_mels,
@@ -128,7 +133,8 @@
                               dither=self.dither,
                               energy_floor=0.0,
                               window_type=self.window,
-                              sample_frequency=self.fs)
+                              sample_frequency=self.fs,
+                              snip_edges=self.snip_edges)
      
             if self.lfr_m != 1 or self.lfr_n != 1:
                 mat = apply_lfr(mat, self.lfr_m, self.lfr_n)

--
Gitblit v1.9.1