aky15
2023-04-12 7d1efe158eda74dc847c397db906f6cb77ac0f84
funasr/models/e2e_transducer_unified.py
File was renamed from funasr/models_transducer/espnet_transducer_model_unified.py
@@ -10,10 +10,10 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models_transducer.encoder.encoder import Encoder
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models_transducer.utils import get_transducer_task_io
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
from funasr.models.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -23,7 +23,7 @@
from funasr.losses.label_smoothing_loss import (  # noqa: H301
    LabelSmoothingLoss,
)
from funasr.models_transducer.error_calculator import ErrorCalculator
from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
@@ -33,7 +33,7 @@
        yield
class ESPnetASRUnifiedTransducerModel(AbsESPnetModel):
class UnifiedTransducerModel(AbsESPnetModel):
    """ESPnet2ASRTransducerModel module definition.
    Args:
@@ -289,7 +289,6 @@
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def collect_feats(