From b8201e02baf4eda3911ddf77d6b85ff526073bca Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 20 四月 2023 17:43:27 +0800
Subject: [PATCH] update

---
 funasr/build_utils/build_model.py     |    3 +
 funasr/build_utils/build_vad_model.py |   77 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 79 insertions(+), 1 deletions(-)

diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py
index 5b1da0c..6029fae 100644
--- a/funasr/build_utils/build_model.py
+++ b/funasr/build_utils/build_model.py
@@ -1,7 +1,8 @@
 from funasr.build_utils.build_asr_model import build_asr_model
-from funasr.build_utils.build_pretrain_model import build_pretrain_model
 from funasr.build_utils.build_lm_model import build_lm_model
+from funasr.build_utils.build_pretrain_model import build_pretrain_model
 from funasr.build_utils.build_punc_model import build_punc_model
+from funasr.build_utils.build_vad_model import build_vad_model
 
 
 def build_model(args):
diff --git a/funasr/build_utils/build_vad_model.py b/funasr/build_utils/build_vad_model.py
new file mode 100644
index 0000000..76eb09b
--- /dev/null
+++ b/funasr/build_utils/build_vad_model.py
@@ -0,0 +1,77 @@
+import torch
+
+from funasr.models.e2e_vad import E2EVadModel
+from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes=dict(
+        default=DefaultFrontend,
+        sliding_window=SlidingWindow,
+        s3prl=S3prlFrontend,
+        fused=FusedFrontends,
+        wav_frontend=WavFrontend,
+        wav_frontend_online=WavFrontendOnline,
+    ),
+    default="default",
+)
+encoder_choices = ClassChoices(
+    "encoder",
+    classes=dict(
+        fsmn=FSMN,
+    ),
+    type_check=torch.nn.Module,
+    default="fsmn",
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        e2evad=E2EVadModel,
+    ),
+    default="e2evad",
+)
+
+class_choices_list = [
+    # --frontend and --frontend_conf
+    frontend_choices,
+    # --encoder and --encoder_conf
+    encoder_choices,
+    # --model and --model_conf
+    model_choices,
+]
+
+
+def build_vad_model(args):
+    # frontend
+    if args.input_size is None:
+        frontend_class = frontend_choices.get_class(args.frontend)
+        if args.frontend == 'wav_frontend':
+            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+        else:
+            frontend = frontend_class(**args.frontend_conf)
+        input_size = frontend.output_size()
+    else:
+        args.frontend = None
+        args.frontend_conf = {}
+        frontend = None
+        input_size = args.input_size
+
+    # encoder
+    encoder_class = encoder_choices.get_class(args.encoder)
+    encoder = encoder_class(**args.encoder_conf)
+
+    model_class = model_choices.get_class(args.model)
+    model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
+
+    # initialize
+    if args.init is not None:
+        initialize(model, args.init)
+
+    return model

--
Gitblit v1.9.1