游雁
2023-12-19 0e622e694e6cb4459955f1e5942a7c53349ce640
funasr/models/bat/model.py
@@ -5,34 +5,23 @@
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from packaging.version import parse as V
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
from funasr.models.transformer.utils.nets_utils import th_accuracy
from funasr.metrics.compute_acc import th_accuracy
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.add_sos_eos import add_sos_eos
from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
    @contextmanager
    def autocast(enabled=True):
        yield
from torch.cuda.amp import autocast
class BATModel(FunASRModel):
class BATModel(nn.Module):
    """BATModel module definition.
    Args:
@@ -61,18 +50,7 @@
    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: RNNTDecoder,
        joint_network: JointNetwork,
        att_decoder: Optional[AbsAttDecoder] = None,
        predictor = None,
        transducer_weight: float = 1.0,
        predictor_weight: float = 1.0,
        cif_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
@@ -89,6 +67,7 @@
        length_normalized_loss: bool = False,
        r_d: int = 5,
        r_u: int = 5,
        **kwargs,
    ) -> None:
        """Construct an BATModel object."""
        super().__init__()