From a557a55b8bdd2923f1b4a9b3e4e0ff402cc05aeb Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期二, 20 六月 2023 10:25:53 +0800
Subject: [PATCH] update funasr-wss-client funasr-wss-server

---
 funasr/bin/asr_infer.py |   16 +++++++++-------
 1 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index e0e2c09..c722ebc 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -316,7 +316,7 @@
         # 2. Build Language model
         if lm_train_config is not None:
             lm, lm_train_args = build_model_from_file(
-                lm_train_config, lm_file, device
+                lm_train_config, lm_file, None, device, task_name="lm"
             )
             scorers["lm"] = lm.lm
 
@@ -636,7 +636,7 @@
         # 2. Build Language model
         if lm_train_config is not None:
             lm, lm_train_args = build_model_from_file(
-                lm_train_config, lm_file, device
+                lm_train_config, lm_file, None, device, task_name="lm"
             )
             scorers["lm"] = lm.lm
 
@@ -1120,7 +1120,7 @@
         # 2. Build Language model
         if lm_train_config is not None:
             lm, lm_train_args = build_model_from_file(
-                lm_train_config, lm_file, device
+                lm_train_config, lm_file, None, device, task_name="lm"
             )
             lm.to(device)
             scorers["lm"] = lm.lm
@@ -1343,7 +1343,7 @@
 
         if lm_train_config is not None:
             lm, lm_train_args = build_model_from_file(
-                lm_train_config, lm_file, device
+                lm_train_config, lm_file, None, device, task_name="lm"
             )
             lm_scorer = lm.lm
         else:
@@ -1636,8 +1636,10 @@
         )
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            if asr_train_args.frontend == 'wav_frontend':
-                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+            from funasr.tasks.sa_asr import frontend_choices
+            if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
+                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
+                frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
             else:
                 frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                 frontend = frontend_class(**asr_train_args.frontend_conf).eval()
@@ -1659,7 +1661,7 @@
         # 2. Build Language model
         if lm_train_config is not None:
             lm, lm_train_args = build_model_from_file(
-                lm_train_config, lm_file, None, device
+                lm_train_config, lm_file, None, device, task_name="lm"
             )
             scorers["lm"] = lm.lm
 

--
Gitblit v1.9.1