From 7d1efe158eda74dc847c397db906f6cb77ac0f84 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 16:49:56 +0800
Subject: [PATCH] rnnt reorg

---
 funasr/tasks/asr_transducer.py |   41 ++++++++++++-----------------------------
 1 files changed, 12 insertions(+), 29 deletions(-)

diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
index be14455..cae18c1 100644
--- a/funasr/tasks/asr_transducer.py
+++ b/funasr/tasks/asr_transducer.py
@@ -21,15 +21,13 @@
     LightweightConvolutionTransformerDecoder,
     TransformerDecoder,
 )
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder
-from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder
-from funasr.models_transducer.encoder.encoder import Encoder
-from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt
-from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel
-from funasr.models_transducer.espnet_transducer_model_unified import ESPnetASRUnifiedTransducerModel
-from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel
-from funasr.models_transducer.joint_network import JointNetwork
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
+from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
+from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.e2e_transducer import TransducerModel
+from funasr.models.e2e_transducer_unified import UnifiedTransducerModel
+from funasr.models.joint_network import JointNetwork
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.layers.global_mvn import GlobalMVN
 from funasr.layers.utterance_mvn import UtteranceMVN
@@ -75,7 +73,6 @@
         "encoder",
         classes=dict(
                 encoder=Encoder,
-                sanm_chunk_opt=SANMEncoderChunkOpt,
         ),
         default="encoder",
 )
@@ -158,7 +155,7 @@
         group.add_argument(
             "--model_conf",
             action=NestedDictAction,
-            default=get_default_kwargs(ESPnetASRTransducerModel),
+            default=get_default_kwargs(TransducerModel),
             help="The keyword arguments for the model class.",
         )
         # group.add_argument(
@@ -354,7 +351,7 @@
         return retval
 
     @classmethod
-    def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel:
+    def build_model(cls, args: argparse.Namespace) -> TransducerModel:
         """Required data depending on task mode.
         Args:
             cls: ASRTransducerTask object.
@@ -440,22 +437,8 @@
 
         # 7. Build model
 
-        if getattr(args, "encoder", None) is not None and args.encoder == 'sanm_chunk_opt':
-            model = UniASRTransducerModel(
-                vocab_size=vocab_size,
-                token_list=token_list,
-                frontend=frontend,
-                specaug=specaug,
-                normalize=normalize,
-                encoder=encoder,
-                decoder=decoder,
-                att_decoder=att_decoder,
-                joint_network=joint_network,
-                **args.model_conf,
-            )
-
-        elif encoder.unified_model_training:
-            model = ESPnetASRUnifiedTransducerModel(
+        if encoder.unified_model_training:
+            model = UnifiedTransducerModel(
                 vocab_size=vocab_size,
                 token_list=token_list,
                 frontend=frontend,
@@ -469,7 +452,7 @@
             )
 
         else:
-            model = ESPnetASRTransducerModel(
+            model = TransducerModel(
                 vocab_size=vocab_size,
                 token_list=token_list,
                 frontend=frontend,

--
Gitblit v1.9.1