嘉渊
2023-04-24 eec914bef61a802a955ea7be4d06284f00efd69a
funasr/models/e2e_tp.py
@@ -2,22 +2,17 @@
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import numpy as np
from typeguard import check_argument_types
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.predictor.cif import mae_loss
from funasr.models.base_model import FunASRModel
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.predictor.cif import CifPredictorV3
@@ -30,15 +25,15 @@
        yield
class TimestampPredictor(AbsESPnetModel):
class TimestampPredictor(FunASRModel):
    """
    Author: Speech Lab, Alibaba Group, China
    """
    def __init__(
            self,
            frontend: Optional[AbsFrontend],
            encoder: AbsEncoder,
            frontend: Optional[torch.nn.Module],
            encoder: torch.nn.Module,
            predictor: CifPredictorV3,
            predictor_bias: int = 0,
            token_list=None,