From 7d1efe158eda74dc847c397db906f6cb77ac0f84 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 16:49:56 +0800
Subject: [PATCH] rnnt reorg
---
funasr/models/e2e_transducer_unified.py | 13 ++++++-------
1 files changed, 6 insertions(+), 7 deletions(-)
diff --git a/funasr/models_transducer/espnet_transducer_model_unified.py b/funasr/models/e2e_transducer_unified.py
similarity index 97%
rename from funasr/models_transducer/espnet_transducer_model_unified.py
rename to funasr/models/e2e_transducer_unified.py
index be61e83..6003542 100644
--- a/funasr/models_transducer/espnet_transducer_model_unified.py
+++ b/funasr/models/e2e_transducer_unified.py
@@ -10,10 +10,10 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models_transducer.encoder.encoder import Encoder
-from funasr.models_transducer.joint_network import JointNetwork
-from funasr.models_transducer.utils import get_transducer_task_io
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.joint_network import JointNetwork
+from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -23,7 +23,7 @@
from funasr.losses.label_smoothing_loss import ( # noqa: H301
LabelSmoothingLoss,
)
-from funasr.models_transducer.error_calculator import ErrorCalculator
+from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
if V(torch.__version__) >= V("1.6.0"):
from torch.cuda.amp import autocast
else:
@@ -33,7 +33,7 @@
yield
-class ESPnetASRUnifiedTransducerModel(AbsESPnetModel):
+class UnifiedTransducerModel(AbsESPnetModel):
"""ESPnet2ASRTransducerModel module definition.
Args:
@@ -289,7 +289,6 @@
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-
return loss, stats, weight
def collect_feats(
--
Gitblit v1.9.1