From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/build_utils/build_asr_model.py |  107 +++++++++++++++++++++++++++++++++++++++++++++--------
 1 files changed, 91 insertions(+), 16 deletions(-)

diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
index 7aa8111..5e93444 100644
--- a/funasr/build_utils/build_asr_model.py
+++ b/funasr/build_utils/build_asr_model.py
@@ -20,21 +20,27 @@
 from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.models.decoder.transformer_decoder import TransformerDecoder
 from funasr.models.decoder.rnnt_decoder import RNNTDecoder
-from funasr.models.joint_net.joint_network import JointNetwork
 from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
 from funasr.models.e2e_asr import ASRModel
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
 from funasr.models.e2e_asr_mfcca import MFCCA
+
+from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
+from funasr.models.e2e_asr_bat import BATModel
+
 from funasr.models.e2e_sa_asr import SAASRModel
 from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
+
 from funasr.models.e2e_tp import TimestampPredictor
 from funasr.models.e2e_uni_asr import UniASR
-from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
 from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
 from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
 from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
 from funasr.models.encoder.resnet34_encoder import ResNet34Diar
 from funasr.models.encoder.rnn_encoder import RNNEncoder
 from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.branchformer_encoder import BranchformerEncoder
+from funasr.models.encoder.e_branchformer_encoder import EBranchformerEncoder
 from funasr.models.encoder.transformer_encoder import TransformerEncoder
 from funasr.models.frontend.default import DefaultFrontend
 from funasr.models.frontend.default import MultiChannelFrontend
@@ -42,7 +48,8 @@
 from funasr.models.frontend.s3prl import S3prlFrontend
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
+from funasr.models.joint_net.joint_network import JointNetwork
+from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
 from funasr.models.specaug.specaug import SpecAug
 from funasr.models.specaug.specaug import SpecAugLFR
 from funasr.modules.subsampling import Conv1dSubsampling
@@ -89,12 +96,13 @@
         paraformer_bert=ParaformerBert,
         bicif_paraformer=BiCifParaformer,
         contextual_paraformer=ContextualParaformer,
+        neatcontextual_paraformer=NeatContextualParaformer,
         mfcca=MFCCA,
         timestamp_prediction=TimestampPredictor,
         rnnt=TransducerModel,
         rnnt_unified=UnifiedTransducerModel,
         sa_asr=SAASRModel,
-
+        bat=BATModel,
     ),
     default="asr",
 )
@@ -107,6 +115,8 @@
         sanm=SANMEncoder,
         sanm_chunk_opt=SANMEncoderChunkOpt,
         data2vec_encoder=Data2VecEncoder,
+        branchformer=BranchformerEncoder,
+        e_branchformer=EBranchformerEncoder,
         mfcca_enc=MFCCAEncoder,
         chunk_conformer=ConformerChunkEncoder,
     ),
@@ -183,6 +193,7 @@
         ctc_predictor=None,
         cif_predictor_v2=CifPredictorV2,
         cif_predictor_v3=CifPredictorV3,
+        bat_predictor=BATPredictor,
     ),
     default="cif_predictor",
     optional=True,
@@ -258,17 +269,22 @@
 
 def build_asr_model(args):
     # token_list
-    if args.token_list is not None:
-        with open(args.token_list) as f:
+    if isinstance(args.token_list, str):
+        with open(args.token_list, encoding="utf-8") as f:
             token_list = [line.rstrip() for line in f]
         args.token_list = list(token_list)
         vocab_size = len(token_list)
         logging.info(f"Vocabulary size: {vocab_size}")
+    elif isinstance(args.token_list, (tuple, list)):
+        token_list = list(args.token_list)
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
     else:
+        token_list = None
         vocab_size = None
 
     # frontend
-    if args.input_size is None:
+    if hasattr(args, "input_size") and args.input_size is None:
         frontend_class = frontend_choices.get_class(args.frontend)
         if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
             frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
@@ -279,7 +295,7 @@
         args.frontend = None
         args.frontend_conf = {}
         frontend = None
-        input_size = args.input_size
+        input_size = args.input_size if hasattr(args, "input_size") else None
 
     # data augmentation for spectrogram
     if args.specaug is not None:
@@ -291,7 +307,10 @@
     # normalization layer
     if args.normalize is not None:
         normalize_class = normalize_choices.get_class(args.normalize)
-        normalize = normalize_class(**args.normalize_conf)
+        if args.model == "mfcca":
+            normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
+        else:
+            normalize = normalize_class(**args.normalize_conf)
     else:
         normalize = None
 
@@ -300,12 +319,15 @@
     encoder = encoder_class(input_size=input_size, **args.encoder_conf)
 
     # decoder
-    decoder_class = decoder_choices.get_class(args.decoder)
-    decoder = decoder_class(
-        vocab_size=vocab_size,
-        encoder_output_size=encoder.output_size(),
-        **args.decoder_conf,
-    )
+    if hasattr(args, "decoder") and args.decoder is not None:
+        decoder_class = decoder_choices.get_class(args.decoder)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder.output_size(),
+            **args.decoder_conf,
+        )
+    else:
+        decoder = None
 
     # ctc
     ctc = CTC(
@@ -325,7 +347,8 @@
             token_list=token_list,
             **args.model_conf,
         )
-    elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
+    elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer",
+                        "contextual_paraformer", "neatcontextual_paraformer"]:
         # predictor
         predictor_class = predictor_choices.get_class(args.predictor)
         predictor = predictor_class(**args.predictor_conf)
@@ -394,10 +417,15 @@
             **args.model_conf,
         )
     elif args.model == "timestamp_prediction":
+        # predictor
+        predictor_class = predictor_choices.get_class(args.predictor)
+        predictor = predictor_class(**args.predictor_conf)
+        
         model_class = model_choices.get_class(args.model)
         model = model_class(
             frontend=frontend,
             encoder=encoder,
+            predictor=predictor,
             token_list=token_list,
             **args.model_conf,
         )
@@ -444,6 +472,53 @@
             joint_network=joint_network,
             **args.model_conf,
         )
+    elif args.model == "bat":
+        # 5. Decoder
+        encoder_output_size = encoder.output_size()
+
+        rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
+        decoder = rnnt_decoder_class(
+            vocab_size,
+            **args.rnnt_decoder_conf,
+        )
+        decoder_output_size = decoder.output_size
+
+        if getattr(args, "decoder", None) is not None:
+            att_decoder_class = decoder_choices.get_class(args.decoder)
+
+            att_decoder = att_decoder_class(
+                vocab_size=vocab_size,
+                encoder_output_size=encoder_output_size,
+                **args.decoder_conf,
+            )
+        else:
+            att_decoder = None
+        # 6. Joint Network
+        joint_network = JointNetwork(
+            vocab_size,
+            encoder_output_size,
+            decoder_output_size,
+            **args.joint_network_conf,
+        )
+
+        predictor_class = predictor_choices.get_class(args.predictor)
+        predictor = predictor_class(**args.predictor_conf)
+
+        model_class = model_choices.get_class(args.model)
+        # 7. Build model
+        model = model_class(
+            vocab_size=vocab_size,
+            token_list=token_list,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            encoder=encoder,
+            decoder=decoder,
+            att_decoder=att_decoder,
+            joint_network=joint_network,
+            predictor=predictor,
+            **args.model_conf,
+        )
     elif args.model == "sa_asr":
         asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
         asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)

--
Gitblit v1.9.1