From b8201e02baf4eda3911ddf77d6b85ff526073bca Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: Thu, 20 Apr 2023 17:43:27 +0800
Subject: [PATCH] update
---
funasr/build_utils/build_model.py | 3 +
funasr/build_utils/build_vad_model.py | 77 ++++++++++++++++++++++++++++++++++++++
2 files changed, 79 insertions(+), 1 deletions(-)
diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py
index 5b1da0c..6029fae 100644
--- a/funasr/build_utils/build_model.py
+++ b/funasr/build_utils/build_model.py
@@ -1,7 +1,8 @@
from funasr.build_utils.build_asr_model import build_asr_model
-from funasr.build_utils.build_pretrain_model import build_pretrain_model
from funasr.build_utils.build_lm_model import build_lm_model
+from funasr.build_utils.build_pretrain_model import build_pretrain_model
from funasr.build_utils.build_punc_model import build_punc_model
+from funasr.build_utils.build_vad_model import build_vad_model
def build_model(args):
diff --git a/funasr/build_utils/build_vad_model.py b/funasr/build_utils/build_vad_model.py
new file mode 100644
index 0000000..76eb09b
--- /dev/null
+++ b/funasr/build_utils/build_vad_model.py
@@ -0,0 +1,77 @@
+import torch
+
+from funasr.models.e2e_vad import E2EVadModel
+from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+# Frontend registry: maps each frontend name to the class that implements it.
+# "default" (DefaultFrontend) is the entry used when none is chosen.
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes={
+        "default": DefaultFrontend,
+        "sliding_window": SlidingWindow,
+        "s3prl": S3prlFrontend,
+        "fused": FusedFrontends,
+        "wav_frontend": WavFrontend,
+        "wav_frontend_online": WavFrontendOnline,
+    },
+    default="default",
+)
+# Encoder registry: "fsmn" (FSMN) is the only entry and also the default.
+# type_check is set to torch.nn.Module.
+encoder_choices = ClassChoices(
+    name="encoder",
+    classes={"fsmn": FSMN},
+    type_check=torch.nn.Module,
+    default="fsmn",
+)
+# Model registry: "e2evad" (E2EVadModel) is the only entry and also the default.
+model_choices = ClassChoices(
+    name="model",
+    classes={"e2evad": E2EVadModel},
+    default="e2evad",
+)
+
+# One registry per option pair, in this order:
+#   --frontend / --frontend_conf
+#   --encoder  / --encoder_conf
+#   --model    / --model_conf
+class_choices_list = [frontend_choices, encoder_choices, model_choices]
+
+
+def build_vad_model(args):
+    """Assemble the VAD model described by ``args``.
+
+    Args:
+        args: Configuration namespace. The attributes read here are
+            ``input_size``, ``frontend``, ``frontend_conf``, ``cmvn_file``
+            (only when ``frontend == "wav_frontend"``), ``encoder``,
+            ``encoder_conf``, ``model``, ``vad_post_conf`` and ``init``.
+            When ``input_size`` is not ``None``, ``args.frontend`` is reset
+            to ``None`` and ``args.frontend_conf`` to ``{}`` in place.
+
+    Returns:
+        An instance of the class selected by ``args.model``, constructed
+        from the encoder, ``args.vad_post_conf`` and the frontend (``None``
+        when no frontend was built). ``initialize`` is applied to it when
+        ``args.init`` is not ``None``.
+    """
+    # frontend
+    if args.input_size is None:
+        frontend_class = frontend_choices.get_class(args.frontend)
+        if args.frontend == "wav_frontend":
+            # Only this frontend is handed the CMVN file.
+            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+        else:
+            frontend = frontend_class(**args.frontend_conf)
+    else:
+        # An explicit input size means no frontend is built; the frontend
+        # settings on ``args`` are cleared so they agree with that.
+        args.frontend = None
+        args.frontend_conf = {}
+        frontend = None
+
+    # encoder: configured solely from ``args.encoder_conf``; the feature
+    # size is not forwarded to it, so it is not computed here.
+    encoder_class = encoder_choices.get_class(args.encoder)
+    encoder = encoder_class(**args.encoder_conf)
+
+    # model
+    model_class = model_choices.get_class(args.model)
+    model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
+
+    # initialize
+    if args.init is not None:
+        initialize(model, args.init)
+
+    return model
--
Gitblit v1.9.1