游雁
2023-05-19 219c2482ab755fbd4e49dfbdee91bf1a8a4ec49a
funasr/models/e2e_sa_asr.py
@@ -16,9 +16,8 @@
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
    LabelSmoothingLoss, NllLoss  # noqa: H301
)
from funasr.losses.nll_loss import NllLoss
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
@@ -30,7 +29,7 @@
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
@@ -41,7 +40,7 @@
        yield
class ESPnetASRModel(AbsESPnetModel):
class ESPnetASRModel(FunASRModel):
    """CTC-attention hybrid Encoder-Decoder model"""
    def __init__(