雾聪
2024-01-15 8a8d60d5786510ec7b1dd4f622e848de8a15f8a8
funasr/models/bat/model.py
@@ -1,38 +1,31 @@
"""Boundary Aware Transducer (BAT) model."""
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import logging
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union
import torch
from packaging.version import parse as V
import logging
import torch.nn as nn
from typing import Dict, List, Optional, Tuple, Union
from torch.cuda.amp import autocast
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
from funasr.models.transformer.utils.nets_utils import th_accuracy
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.add_sos_eos import add_sos_eos
from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
    @contextmanager
    def autocast(enabled=True):
        yield
class BATModel(FunASRModel):
class BATModel(nn.Module):
    """BATModel module definition.
    Args:
@@ -61,18 +54,7 @@
    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: RNNTDecoder,
        joint_network: JointNetwork,
        att_decoder: Optional[AbsAttDecoder] = None,
        predictor = None,
        transducer_weight: float = 1.0,
        predictor_weight: float = 1.0,
        cif_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
@@ -89,6 +71,7 @@
        length_normalized_loss: bool = False,
        r_d: int = 5,
        r_u: int = 5,
        **kwargs,
    ) -> None:
        """Construct an BATModel object."""
        super().__init__()