From 289cb1d2c8d2fc5a54e9b0fb07b2c33800408d42 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 19 六月 2023 17:52:05 +0800
Subject: [PATCH] update repo

---
 funasr/build_utils/build_asr_model.py |   22 +++++++++++++++-------
 1 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
index e5bed1d..d4a954c 100644
--- a/funasr/build_utils/build_asr_model.py
+++ b/funasr/build_utils/build_asr_model.py
@@ -6,6 +6,7 @@
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
 from funasr.models.decoder.rnn_decoder import RNNDecoder
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
 from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
 from funasr.models.decoder.transformer_decoder import (
     DynamicConvolution2DTransformerDecoder,  # noqa: H301
@@ -19,14 +20,14 @@
 )
 from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.models.decoder.transformer_decoder import TransformerDecoder
-from funasr.models.decoder.rnnt_decoder import RNNTDecoder
-from funasr.models.joint_net.joint_network import JointNetwork
 from funasr.models.e2e_asr import ASRModel
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
 from funasr.models.e2e_asr_mfcca import MFCCA
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, \
+    ContextualParaformer
+from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
 from funasr.models.e2e_tp import TimestampPredictor
 from funasr.models.e2e_uni_asr import UniASR
-from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
 from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
 from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
 from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
@@ -39,6 +40,7 @@
 from funasr.models.frontend.s3prl import S3prlFrontend
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.joint_net.joint_network import JointNetwork
 from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
 from funasr.models.specaug.specaug import SpecAug
 from funasr.models.specaug.specaug import SpecAugLFR
@@ -86,6 +88,7 @@
         paraformer_bert=ParaformerBert,
         bicif_paraformer=BiCifParaformer,
         contextual_paraformer=ContextualParaformer,
+        neatcontextual_paraformer=NeatContextualParaformer,
         mfcca=MFCCA,
         timestamp_prediction=TimestampPredictor,
         rnnt=TransducerModel,
@@ -238,6 +241,7 @@
         vocab_size = len(token_list)
         logging.info(f"Vocabulary size: {vocab_size}")
     else:
+        token_list = None
         vocab_size = None
 
     # frontend
@@ -252,7 +256,7 @@
         args.frontend = None
         args.frontend_conf = {}
         frontend = None
-        input_size = args.input_size
+        input_size = args.input_size if hasattr(args, "input_size") else None
 
     # data augmentation for spectrogram
     if args.specaug is not None:
@@ -264,7 +268,10 @@
     # normalization layer
     if args.normalize is not None:
         normalize_class = normalize_choices.get_class(args.normalize)
-        normalize = normalize_class(**args.normalize_conf)
+        if args.model == "mfcca":
+            normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
+        else:
+            normalize = normalize_class(**args.normalize_conf)
     else:
         normalize = None
 
@@ -298,7 +305,8 @@
             token_list=token_list,
             **args.model_conf,
         )
-    elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
+    elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer",
+                        "contextual_paraformer", "neatcontextual_paraformer"]:
         # predictor
         predictor_class = predictor_choices.get_class(args.predictor)
         predictor = predictor_class(**args.predictor_conf)

--
Gitblit v1.9.1