From 18b1449d1ff06c469e54190508c4f6be05c73d85 Mon Sep 17 00:00:00 2001
From: 夜雨飘零 <yeyupiaoling@foxmail.com>
Date: 星期二, 05 十二月 2023 22:04:14 +0800
Subject: [PATCH] 分角色语音识别支持更多的模型

---
 funasr/bin/asr_inference_launch.py |   25 +++++++++++++------------
 1 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 402a911..59e61ee 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -51,10 +51,10 @@
 from funasr.utils.speaker_utils import (check_audio_list,
                                         sv_preprocess,
                                         sv_chunk,
-                                        CAMPPlus,
                                         extract_feature,
                                         postprocess,
-                                        distribute_spk, ERes2Net)
+                                        distribute_spk)
+import funasr.modules.cnn as sv_module
 from funasr.build_utils.build_model_from_file import build_model_from_file
 from funasr.utils.cluster_backend import ClusterBackend
 from funasr.utils.modelscope_utils import get_cache_dir
@@ -818,11 +818,15 @@
         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
     )
 
-    sv_model_file = asr_model_file.replace("model.pb", "campplus_cn_common.bin")
-    if not os.path.exists(sv_model_file):
-        sv_model_file = asr_model_file.replace("model.pb", "pretrained_eres2net_aug.ckpt")
-        if not os.path.exists(sv_model_file):
-            raise FileNotFoundError("sv_model_file not found: {}".format(sv_model_file))
+    sv_model_config_path = asr_model_file.replace("model.pb", "sv_model_config.yaml")
+    if not os.path.exists(sv_model_config_path):
+        sv_model_config = {'sv_model_class': 'CAMPPlus','sv_model_file': 'campplus_cn_common.bin', 'models_config': {}}
+    else:
+        with open(sv_model_config_path, 'r') as f:
+            sv_model_config = yaml.load(f, Loader=yaml.FullLoader)
+    if sv_model_config['models_config'] is None:
+        sv_model_config['models_config'] = {}
+    sv_model_file = asr_model_file.replace("model.pb", sv_model_config['sv_model_file'])
 
     if param_dict is not None:
         hotword_list_or_file = param_dict.get('hotword')
@@ -949,14 +953,11 @@
             ##################################
             # load sv model
             sv_model_dict = torch.load(sv_model_file)
-            print(f'load sv model params: {sv_model_file}')
-            if os.path.basename(sv_model_file) == "campplus_cn_common.bin":
-                sv_model = CAMPPlus()
-            else:
-                sv_model = ERes2Net()
+            sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
             if ngpu > 0:
                 sv_model.cuda()
             sv_model.load_state_dict(sv_model_dict)
+            print(f'load sv model params: {sv_model_file}')
             sv_model.eval()
             cb_model = ClusterBackend()
             vad_segments = []

--
Gitblit v1.9.1