From 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: Wed, 12 Apr 2023 18:03:06 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/main' into dev_aky
---
funasr/tasks/asr_transducer.py | 48 +++++++++++++++++++-----------------------------
1 file changed, 19 insertions(+), 29 deletions(-)
diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
index 3c7a782..cae18c1 100644
--- a/funasr/tasks/asr_transducer.py
+++ b/funasr/tasks/asr_transducer.py
@@ -21,15 +21,13 @@
LightweightConvolutionTransformerDecoder,
TransformerDecoder,
)
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder
-from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder
-from funasr.models_transducer.encoder.encoder import Encoder
-from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt
-from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel
-from funasr.models_transducer.espnet_transducer_model_unified import ESPnetASRUnifiedTransducerModel
-from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel
-from funasr.models_transducer.joint_network import JointNetwork
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
+from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
+from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.e2e_transducer import TransducerModel
+from funasr.models.e2e_transducer_unified import UnifiedTransducerModel
+from funasr.models.joint_network import JointNetwork
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
@@ -75,7 +73,6 @@
"encoder",
classes=dict(
encoder=Encoder,
- sanm_chunk_opt=SANMEncoderChunkOpt,
),
default="encoder",
)
@@ -138,6 +135,12 @@
help="Integer-string mapper for tokens.",
)
group.add_argument(
+ "--split_with_space",
+ type=str2bool,
+ default=True,
+ help="whether to split text using <space>",
+ )
+ group.add_argument(
"--input_size",
type=int_or_none,
default=None,
@@ -152,7 +155,7 @@
group.add_argument(
"--model_conf",
action=NestedDictAction,
- default=get_default_kwargs(ESPnetASRTransducerModel),
+ default=get_default_kwargs(TransducerModel),
help="The keyword arguments for the model class.",
)
# group.add_argument(
@@ -289,6 +292,7 @@
non_linguistic_symbols=args.non_linguistic_symbols,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
+ split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
rir_apply_prob=args.rir_apply_prob
if hasattr(args, "rir_apply_prob")
@@ -347,7 +351,7 @@
return retval
@classmethod
- def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel:
+ def build_model(cls, args: argparse.Namespace) -> TransducerModel:
"""Required data depending on task mode.
Args:
cls: ASRTransducerTask object.
@@ -433,22 +437,8 @@
# 7. Build model
- if getattr(args, "encoder", None) is not None and args.encoder == 'sanm_chunk_opt':
- model = UniASRTransducerModel(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- decoder=decoder,
- att_decoder=att_decoder,
- joint_network=joint_network,
- **args.model_conf,
- )
-
- elif encoder.unified_model_training:
- model = ESPnetASRUnifiedTransducerModel(
+ if encoder.unified_model_training:
+ model = UnifiedTransducerModel(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
@@ -462,7 +452,7 @@
)
else:
- model = ESPnetASRTransducerModel(
+ model = TransducerModel(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
--
Gitblit v1.9.1