From 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: Wed, 12 Apr 2023 18:03:06 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/main' into dev_aky

---
 funasr/tasks/asr_transducer.py |   48 +++++++++++++++++++-----------------------------
 1 file changed, 19 insertions(+), 29 deletions(-)

diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
index 3c7a782..cae18c1 100644
--- a/funasr/tasks/asr_transducer.py
+++ b/funasr/tasks/asr_transducer.py
@@ -21,15 +21,13 @@
     LightweightConvolutionTransformerDecoder,
     TransformerDecoder,
 )
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder
-from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder
-from funasr.models_transducer.encoder.encoder import Encoder
-from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt
-from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel
-from funasr.models_transducer.espnet_transducer_model_unified import ESPnetASRUnifiedTransducerModel
-from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel
-from funasr.models_transducer.joint_network import JointNetwork
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
+from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
+from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.e2e_transducer import TransducerModel
+from funasr.models.e2e_transducer_unified import UnifiedTransducerModel
+from funasr.models.joint_network import JointNetwork
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.layers.global_mvn import GlobalMVN
 from funasr.layers.utterance_mvn import UtteranceMVN
@@ -75,7 +73,6 @@
         "encoder",
         classes=dict(
                 encoder=Encoder,
-                sanm_chunk_opt=SANMEncoderChunkOpt,
         ),
         default="encoder",
 )
@@ -138,6 +135,12 @@
             help="Integer-string mapper for tokens.",
         )
         group.add_argument(
+            "--split_with_space",
+            type=str2bool,
+            default=True,
+            help="whether to split text using <space>",
+        )
+        group.add_argument(
             "--input_size",
             type=int_or_none,
             default=None,
@@ -152,7 +155,7 @@
         group.add_argument(
             "--model_conf",
             action=NestedDictAction,
-            default=get_default_kwargs(ESPnetASRTransducerModel),
+            default=get_default_kwargs(TransducerModel),
             help="The keyword arguments for the model class.",
         )
         # group.add_argument(
@@ -289,6 +292,7 @@
                 non_linguistic_symbols=args.non_linguistic_symbols,
                 text_cleaner=args.cleaner,
                 g2p_type=args.g2p,
+                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
                 rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
                 rir_apply_prob=args.rir_apply_prob
                 if hasattr(args, "rir_apply_prob")
@@ -347,7 +351,7 @@
         return retval
 
     @classmethod
-    def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel:
+    def build_model(cls, args: argparse.Namespace) -> TransducerModel:
         """Required data depending on task mode.
         Args:
             cls: ASRTransducerTask object.
@@ -433,22 +437,8 @@
 
         # 7. Build model
 
-        if getattr(args, "encoder", None) is not None and args.encoder == 'sanm_chunk_opt':
-            model = UniASRTransducerModel(
-                vocab_size=vocab_size,
-                token_list=token_list,
-                frontend=frontend,
-                specaug=specaug,
-                normalize=normalize,
-                encoder=encoder,
-                decoder=decoder,
-                att_decoder=att_decoder,
-                joint_network=joint_network,
-                **args.model_conf,
-            )
-
-        elif encoder.unified_model_training:
-            model = ESPnetASRUnifiedTransducerModel(
+        if encoder.unified_model_training:
+            model = UnifiedTransducerModel(
                 vocab_size=vocab_size,
                 token_list=token_list,
                 frontend=frontend,
@@ -462,7 +452,7 @@
             )
 
         else:
-            model = ESPnetASRTransducerModel(
+            model = TransducerModel(
                 vocab_size=vocab_size,
                 token_list=token_list,
                 frontend=frontend,

--
Gitblit v1.9.1