From 2a66366be4c2715870e4859fd5a5db6e8a9dc00a Mon Sep 17 00:00:00 2001
From: chenmengzheAAA <123789350+chenmengzheAAA@users.noreply.github.com>
Date: 星期四, 14 九月 2023 19:00:17 +0800
Subject: [PATCH] Merge pull request #956 from alibaba-damo-academy/chenmengzheAAA-patch-4

---
 funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py |   21 +++++++++++----------
 1 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py b/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py
index 6464964..c556daf 100644
--- a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py
+++ b/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py
@@ -105,11 +105,10 @@
             frame_shift: int = 10,
             filter_length_min: int = -1,
             filter_length_max: float = -1,
-            lfr_m: int = 1,
-            lfr_n: int = 1,
+            lfr_m: int = 7,
+            lfr_n: int = 6,
             dither: float = 1.0
     ) -> None:
-        # check_argument_types()
 
         self.fs = fs
         self.window = window
@@ -229,22 +228,24 @@
             if key == "config_path":
                 with open(str(value), 'rb') as f:
                     config = yaml.load(f, Loader=yaml.Loader)
+            if key == "cmvn_path":
+                cmvn_path = str(value)
 
         opts = kaldifeat.FbankOptions()
         opts.frame_opts.dither = 1.0 # TODO: 0.0 or 1.0
-        opts.frame_opts.window_type = config['WavFrontend']['frontend_conf']['window']
-        opts.mel_opts.num_bins = int(config['WavFrontend']['frontend_conf']['n_mels'])
-        opts.frame_opts.frame_shift_ms = float(config['WavFrontend']['frontend_conf']['frame_shift'])
-        opts.frame_opts.frame_length_ms = float(config['WavFrontend']['frontend_conf']['frame_length'])
-        opts.frame_opts.samp_freq = int(config['WavFrontend']['frontend_conf']['fs'])
+        opts.frame_opts.window_type = config['frontend_conf']['window']
+        opts.mel_opts.num_bins = int(config['frontend_conf']['n_mels'])
+        opts.frame_opts.frame_shift_ms = float(config['frontend_conf']['frame_shift'])
+        opts.frame_opts.frame_length_ms = float(config['frontend_conf']['frame_length'])
+        opts.frame_opts.samp_freq = int(config['frontend_conf']['fs'])
         opts.device = torch.device(self.device)
         self.opts = opts
         self.feature_extractor = Fbank(self.opts)
         self.feature_size = opts.mel_opts.num_bins
 
         self.frontend = WavFrontend(
-            cmvn_file=config['WavFrontend']['cmvn_file'],
-            **config['WavFrontend']['frontend_conf'])
+            cmvn_file=cmvn_path,
+            **config['frontend_conf'])
 
     def extract_feat(self,
                      waveform_list: List[np.ndarray]

--
Gitblit v1.9.1