From 7d1efe158eda74dc847c397db906f6cb77ac0f84 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 16:49:56 +0800
Subject: [PATCH] rnnt reorg
---
funasr/tasks/asr_transducer.py | 41 ++++++++++++-----------------------------
1 files changed, 12 insertions(+), 29 deletions(-)
diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
index be14455..cae18c1 100644
--- a/funasr/tasks/asr_transducer.py
+++ b/funasr/tasks/asr_transducer.py
@@ -21,15 +21,13 @@
LightweightConvolutionTransformerDecoder,
TransformerDecoder,
)
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder
-from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder
-from funasr.models_transducer.encoder.encoder import Encoder
-from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt
-from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel
-from funasr.models_transducer.espnet_transducer_model_unified import ESPnetASRUnifiedTransducerModel
-from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel
-from funasr.models_transducer.joint_network import JointNetwork
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
+from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
+from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.e2e_transducer import TransducerModel
+from funasr.models.e2e_transducer_unified import UnifiedTransducerModel
+from funasr.models.joint_network import JointNetwork
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
@@ -75,7 +73,6 @@
"encoder",
classes=dict(
encoder=Encoder,
- sanm_chunk_opt=SANMEncoderChunkOpt,
),
default="encoder",
)
@@ -158,7 +155,7 @@
group.add_argument(
"--model_conf",
action=NestedDictAction,
- default=get_default_kwargs(ESPnetASRTransducerModel),
+ default=get_default_kwargs(TransducerModel),
help="The keyword arguments for the model class.",
)
# group.add_argument(
@@ -354,7 +351,7 @@
return retval
@classmethod
- def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel:
+ def build_model(cls, args: argparse.Namespace) -> TransducerModel:
"""Required data depending on task mode.
Args:
cls: ASRTransducerTask object.
@@ -440,22 +437,8 @@
# 7. Build model
- if getattr(args, "encoder", None) is not None and args.encoder == 'sanm_chunk_opt':
- model = UniASRTransducerModel(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- decoder=decoder,
- att_decoder=att_decoder,
- joint_network=joint_network,
- **args.model_conf,
- )
-
- elif encoder.unified_model_training:
- model = ESPnetASRUnifiedTransducerModel(
+ if encoder.unified_model_training:
+ model = UnifiedTransducerModel(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
@@ -469,7 +452,7 @@
)
else:
- model = ESPnetASRTransducerModel(
+ model = TransducerModel(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
--
Gitblit v1.9.1