nichongjia-2007
2023-05-31 cc2c1d1d53dea5d2c45f858d1baa5bd279f47987
funasr/models/e2e_asr_mfcca.py
@@ -23,7 +23,7 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
@@ -35,7 +35,8 @@
import pdb
import random
import math
class MFCCA(AbsESPnetModel):
class MFCCA(FunASRModel):
    """
    Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
    MFCCA:Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario
@@ -49,7 +50,6 @@
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
@@ -63,6 +63,7 @@
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
            preencoder: Optional[AbsPreEncoder] = None,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -78,7 +79,6 @@
        self.token_list = token_list.copy()
        
        self.mask_ratio = mask_ratio
        
        self.frontend = frontend
        self.specaug = specaug
@@ -120,7 +120,6 @@
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
@@ -208,7 +207,6 @@
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )