From 200d1ede05e6bc41ef1da6debf7b86df84995fb5 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 20 四月 2023 15:56:25 +0800
Subject: [PATCH] update
---
funasr/utils/build_asr_model.py | 106 +++++++++++++++++++++++++++++++++++++++++------------
1 files changed, 82 insertions(+), 24 deletions(-)
diff --git a/funasr/utils/build_asr_model.py b/funasr/utils/build_asr_model.py
index f8baa47..2da050c 100644
--- a/funasr/utils/build_asr_model.py
+++ b/funasr/utils/build_asr_model.py
@@ -40,6 +40,7 @@
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
+from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
frontend_choices = ClassChoices(
@@ -171,29 +172,30 @@
optional=True,
)
class_choices_list = [
- # --frontend and --frontend_conf
- frontend_choices,
- # --specaug and --specaug_conf
- specaug_choices,
- # --normalize and --normalize_conf
- normalize_choices,
- # --model and --model_conf
- model_choices,
- # --encoder and --encoder_conf
- encoder_choices,
- # --decoder and --decoder_conf
- decoder_choices,
- # --predictor and --predictor_conf
- predictor_choices,
- # --encoder2 and --encoder2_conf
- encoder_choices2,
- # --decoder2 and --decoder2_conf
- decoder_choices2,
- # --predictor2 and --predictor2_conf
- predictor_choices2,
- # --stride_conv and --stride_conv_conf
- stride_conv_choices,
- ]
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --model and --model_conf
+ model_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --decoder and --decoder_conf
+ decoder_choices,
+ # --predictor and --predictor_conf
+ predictor_choices,
+ # --encoder2 and --encoder2_conf
+ encoder_choices2,
+ # --decoder2 and --decoder2_conf
+ decoder_choices2,
+ # --predictor2 and --predictor2_conf
+ predictor_choices2,
+ # --stride_conv and --stride_conv_conf
+ stride_conv_choices,
+]
+
def build_asr_model(args):
# token_list
@@ -270,6 +272,7 @@
# predictor
predictor_class = predictor_choices.get_class(args.predictor)
predictor = predictor_class(**args.predictor_conf)
+
model_class = model_choices.get_class(args.model)
model = model_class(
vocab_size=vocab_size,
@@ -283,4 +286,59 @@
predictor=predictor,
**args.model_conf,
)
- elif
+ elif args.model == "uniasr":
+ # stride_conv
+ stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
+ stride_conv = stride_conv_class(**args.stride_conv_conf, idim=input_size + encoder.output_size(),
+ odim=input_size + encoder.output_size())
+ stride_conv_output_size = stride_conv.output_size()
+
+ # encoder2
+ encoder_class2 = encoder_choices2.get_class(args.encoder2)
+ encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)
+
+ # decoder2
+ decoder_class2 = decoder_choices2.get_class(args.decoder2)
+ decoder2 = decoder_class2(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder2.output_size(),
+ **args.decoder2_conf,
+ )
+
+ # ctc2
+ ctc2 = CTC(
+ odim=vocab_size, encoder_output_size=encoder2.output_size(), **args.ctc_conf
+ )
+
+ # predictor
+ predictor_class = predictor_choices.get_class(args.predictor)
+ predictor = predictor_class(**args.predictor_conf)
+
+ # predictor2
+ predictor_class = predictor_choices2.get_class(args.predictor2)
+ predictor2 = predictor_class(**args.predictor2_conf)
+
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ decoder=decoder,
+ ctc=ctc,
+ token_list=token_list,
+ predictor=predictor,
+ ctc2=ctc2,
+ encoder2=encoder2,
+ decoder2=decoder2,
+ predictor2=predictor2,
+ stride_conv=stride_conv,
+ **args.model_conf,
+ )
+
+ else:
+ raise NotImplementedError("Not supported model: {}".format(args.model))
+
+ if args.init is not None:
+ initialize(model, args.init)
--
Gitblit v1.9.1