From 200d1ede05e6bc41ef1da6debf7b86df84995fb5 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 20 四月 2023 15:56:25 +0800
Subject: [PATCH] update

---
 funasr/utils/build_asr_model.py |  106 +++++++++++++++++++++++++++++++++++++++++------------
 1 files changed, 82 insertions(+), 24 deletions(-)

diff --git a/funasr/utils/build_asr_model.py b/funasr/utils/build_asr_model.py
index f8baa47..2da050c 100644
--- a/funasr/utils/build_asr_model.py
+++ b/funasr/utils/build_asr_model.py
@@ -40,6 +40,7 @@
 from funasr.models.specaug.specaug import SpecAug
 from funasr.models.specaug.specaug import SpecAugLFR
 from funasr.modules.subsampling import Conv1dSubsampling
+from funasr.torch_utils.initialize import initialize
 from funasr.train.class_choices import ClassChoices
 
 frontend_choices = ClassChoices(
@@ -171,29 +172,30 @@
     optional=True,
 )
 class_choices_list = [
-        # --frontend and --frontend_conf
-        frontend_choices,
-        # --specaug and --specaug_conf
-        specaug_choices,
-        # --normalize and --normalize_conf
-        normalize_choices,
-        # --model and --model_conf
-        model_choices,
-        # --encoder and --encoder_conf
-        encoder_choices,
-        # --decoder and --decoder_conf
-        decoder_choices,
-        # --predictor and --predictor_conf
-        predictor_choices,
-        # --encoder2 and --encoder2_conf
-        encoder_choices2,
-        # --decoder2 and --decoder2_conf
-        decoder_choices2,
-        # --predictor2 and --predictor2_conf
-        predictor_choices2,
-        # --stride_conv and --stride_conv_conf
-        stride_conv_choices,
-    ]
+    # --frontend and --frontend_conf
+    frontend_choices,
+    # --specaug and --specaug_conf
+    specaug_choices,
+    # --normalize and --normalize_conf
+    normalize_choices,
+    # --model and --model_conf
+    model_choices,
+    # --encoder and --encoder_conf
+    encoder_choices,
+    # --decoder and --decoder_conf
+    decoder_choices,
+    # --predictor and --predictor_conf
+    predictor_choices,
+    # --encoder2 and --encoder2_conf
+    encoder_choices2,
+    # --decoder2 and --decoder2_conf
+    decoder_choices2,
+    # --predictor2 and --predictor2_conf
+    predictor_choices2,
+    # --stride_conv and --stride_conv_conf
+    stride_conv_choices,
+]
+
 
 def build_asr_model(args):
     # token_list
@@ -270,6 +272,7 @@
         # predictor
         predictor_class = predictor_choices.get_class(args.predictor)
         predictor = predictor_class(**args.predictor_conf)
+
         model_class = model_choices.get_class(args.model)
         model = model_class(
             vocab_size=vocab_size,
@@ -283,4 +286,59 @@
             predictor=predictor,
             **args.model_conf,
         )
-    elif
+    elif args.model == "uniasr":
+        # stride_conv
+        stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
+        stride_conv = stride_conv_class(**args.stride_conv_conf, idim=input_size + encoder.output_size(),
+                                        odim=input_size + encoder.output_size())
+        stride_conv_output_size = stride_conv.output_size()
+
+        # encoder2
+        encoder_class2 = encoder_choices2.get_class(args.encoder2)
+        encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)
+
+        # decoder2
+        decoder_class2 = decoder_choices2.get_class(args.decoder2)
+        decoder2 = decoder_class2(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder2.output_size(),
+            **args.decoder2_conf,
+        )
+
+        # ctc2
+        ctc2 = CTC(
+            odim=vocab_size, encoder_output_size=encoder2.output_size(), **args.ctc_conf
+        )
+
+        # predictor
+        predictor_class = predictor_choices.get_class(args.predictor)
+        predictor = predictor_class(**args.predictor_conf)
+
+        # predictor2
+        predictor_class = predictor_choices2.get_class(args.predictor2)
+        predictor2 = predictor_class(**args.predictor2_conf)
+
+        model_class = model_choices.get_class(args.model)
+        model = model_class(
+            vocab_size=vocab_size,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            encoder=encoder,
+            decoder=decoder,
+            ctc=ctc,
+            token_list=token_list,
+            predictor=predictor,
+            ctc2=ctc2,
+            encoder2=encoder2,
+            decoder2=decoder2,
+            predictor2=predictor2,
+            stride_conv=stride_conv,
+            **args.model_conf,
+        )
+
+    else:
+        raise NotImplementedError("Not supported model: {}".format(args.model))
+
+    if args.init is not None:
+        initialize(model, args.init)

--
Gitblit v1.9.1