shixian.shi
2024-01-10 e30a17cf4e715b3d139fa1e0ba01cda1bcf0f884
funasr/models/uniasr/e2e_uni_asr.py
@@ -10,15 +10,15 @@
import torch
from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.models.transformer.utils.nets_utils import th_accuracy
from funasr.models.transformer.add_sos_eos import add_sos_eos
from funasr.metrics.compute_acc import th_accuracy
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.frontends.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
@@ -26,7 +26,7 @@
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
from funasr.models.scama.chunk_utilis import sequence_mask
from funasr.models.predictor.cif import mae_loss
from funasr.models.paraformer.cif_predictor import mae_loss
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast