From 594b79f59e7eefa6955c729f6264c8c99d1d9571 Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期一, 05 六月 2023 16:46:15 +0800
Subject: [PATCH] Merge pull request #591 from alibaba-damo-academy/dev_lhn
---
funasr/tasks/asr.py | 64 ++++++++++++++-----------------
1 files changed, 29 insertions(+), 35 deletions(-)
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 0bb0563..8244856 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -132,6 +132,8 @@
neatcontextual_paraformer=NeatContextualParaformer,
mfcca=MFCCA,
timestamp_prediction=TimestampPredictor,
+ rnnt=TransducerModel,
+ rnnt_unified=UnifiedTransducerModel,
),
type_check=FunASRModel,
default="asr",
@@ -222,6 +224,15 @@
),
type_check=RNNTDecoder,
default="rnnt",
+)
+
+joint_network_choices = ClassChoices(
+ name="joint_network",
+ classes=dict(
+ joint_network=JointNetwork,
+ ),
+ default="joint_network",
+ optional=True,
)
predictor_choices = ClassChoices(
@@ -351,12 +362,6 @@
action=NestedDictAction,
default=get_default_kwargs(CTC),
help="The keyword arguments for CTC class.",
- )
- group.add_argument(
- "--joint_net_conf",
- action=NestedDictAction,
- default=None,
- help="The keyword arguments for joint network class.",
)
group = parser.add_argument_group(description="Preprocess related")
@@ -1368,6 +1373,7 @@
num_optimizers: int = 1
class_choices_list = [
+ model_choices,
frontend_choices,
specaug_choices,
normalize_choices,
@@ -1444,7 +1450,7 @@
decoder_output_size = decoder.output_size
if getattr(args, "decoder", None) is not None:
- att_decoder_class = decoder_choices.get_class(args.att_decoder)
+ att_decoder_class = decoder_choices.get_class(args.decoder)
att_decoder = att_decoder_class(
vocab_size=vocab_size,
@@ -1462,35 +1468,23 @@
)
# 7. Build model
+ try:
+ model_class = model_choices.get_class(args.model)
+ except AttributeError:
+ model_class = model_choices.get_class("rnnt_unified")
- if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
- model = UnifiedTransducerModel(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- decoder=decoder,
- att_decoder=att_decoder,
- joint_network=joint_network,
- **args.model_conf,
- )
-
- else:
- model = TransducerModel(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- encoder=encoder,
- decoder=decoder,
- att_decoder=att_decoder,
- joint_network=joint_network,
- **args.model_conf,
- )
-
+ model = model_class(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ decoder=decoder,
+ att_decoder=att_decoder,
+ joint_network=joint_network,
+ **args.model_conf,
+ )
# 8. Initialize model
if args.init is not None:
raise NotImplementedError(
--
Gitblit v1.9.1