From 5a637c6995f80a41be7026214f4f5a76ad07db70 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 19 三月 2024 11:45:09 +0800
Subject: [PATCH] vad conf

---
 examples/industrial_data_pretraining/emotion2vec/demo.py |    6 +++++-
 funasr/auto/auto_model.py                                |    8 +++-----
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/examples/industrial_data_pretraining/emotion2vec/demo.py b/examples/industrial_data_pretraining/emotion2vec/demo.py
index 227f504..b274bd9 100644
--- a/examples/industrial_data_pretraining/emotion2vec/demo.py
+++ b/examples/industrial_data_pretraining/emotion2vec/demo.py
@@ -6,7 +6,11 @@
 from funasr import AutoModel
 
 # model="iic/emotion2vec_base"
-model = AutoModel(model="iic/emotion2vec_base_finetuned", model_revision="v2.0.4")
+model = AutoModel(model="iic/emotion2vec_base_finetuned", model_revision="v2.0.4",
+                  # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  # vad_model_revision="v2.0.4",
+                  # vad_kwargs={"max_single_segment_time": 10},
+                  )
 
 wav_file = f"{model.model_path}/example/test.wav"
 res = model.generate(wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False)
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 39f91e9..d044455 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -106,9 +106,9 @@
         
         # if vad_model is not None, build vad model else None
         vad_model = kwargs.get("vad_model", None)
+        vad_kwargs = {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
         if vad_model is not None:
             logging.info("Building VAD model.")
-            vad_kwargs = {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
             vad_kwargs["model"] = vad_model
             vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", None)
             vad_kwargs["device"] = kwargs["device"]
@@ -116,10 +116,9 @@
 
         # if punc_model is not None, build punc model else None
         punc_model = kwargs.get("punc_model", None)
-        
+        punc_kwargs = {} if kwargs.get("punc_kwargs", {}) is None else kwargs.get("punc_kwargs", {})
         if punc_model is not None:
             logging.info("Building punc model.")
-            punc_kwargs = {} if kwargs.get("punc_kwargs", {}) is None else kwargs.get("punc_kwargs", {})
             punc_kwargs["model"] = punc_model
             punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", None)
             punc_kwargs["device"] = kwargs["device"]
@@ -127,10 +126,9 @@
 
         # if spk_model is not None, build spk model else None
         spk_model = kwargs.get("spk_model", None)
-        spk_kwargs = kwargs.get("spk_model_revision", None)
+        spk_kwargs = {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
         if spk_model is not None:
             logging.info("Building SPK model.")
-            spk_kwargs = {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
             spk_kwargs["model"] = spk_model
             spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", None)
             spk_kwargs["device"] = kwargs["device"]

--
Gitblit v1.9.1