游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/models/e2e_asr_mfcca.py
@@ -7,7 +7,6 @@
from typing import Union
import logging
import torch
from typeguard import check_argument_types
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
@@ -17,10 +16,13 @@
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.base_model import FunASRModel
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
@@ -32,32 +34,36 @@
import pdb
import random
import math
class MFCCA(FunASRModel):
    """CTC-attention hybrid Encoder-Decoder model"""
    """
    Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
    MFCCA:Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario
    https://arxiv.org/abs/2210.05265
    """
    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[torch.nn.Module],
        specaug: Optional[torch.nn.Module],
        normalize: Optional[torch.nn.Module],
        preencoder: Optional[AbsPreEncoder],
        encoder: torch.nn.Module,
        decoder: AbsDecoder,
        ctc: CTC,
        rnnt_decoder: None,
        ctc_weight: float = 0.5,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        mask_ratio: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
            self,
            vocab_size: int,
            token_list: Union[Tuple[str, ...], List[str]],
            frontend: Optional[AbsFrontend],
            specaug: Optional[AbsSpecAug],
            normalize: Optional[AbsNormalize],
            encoder: AbsEncoder,
            decoder: AbsDecoder,
            ctc: CTC,
            rnnt_decoder: None = None,
            ctc_weight: float = 0.5,
            ignore_id: int = -1,
            lsm_weight: float = 0.0,
            mask_ratio: float = 0.0,
            length_normalized_loss: bool = False,
            report_cer: bool = True,
            report_wer: bool = True,
            sym_space: str = "<space>",
            sym_blank: str = "<blank>",
            preencoder: Optional[AbsPreEncoder] = None,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert rnnt_decoder is None, "Not implemented"
@@ -69,10 +75,9 @@
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.token_list = token_list.copy()
        self.mask_ratio = mask_ratio
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
@@ -106,14 +111,13 @@
            self.error_calculator = None
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
@@ -123,22 +127,22 @@
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
                speech.shape[0]
                == speech_lengths.shape[0]
                == text.shape[0]
                == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        #pdb.set_trace()
        if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
        # pdb.set_trace()
        if (speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0):
            rate_num = random.random()
            #rate_num = 0.1
            if(rate_num<=self.mask_ratio):
                retain_channel = math.ceil(random.random() *8)
                if(retain_channel>1):
                    speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
            # rate_num = 0.1
            if (rate_num <= self.mask_ratio):
                retain_channel = math.ceil(random.random() * 8)
                if (retain_channel > 1):
                    speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
                else:
                    speech = speech[:,:,torch.randperm(8)[0]]
        #pdb.set_trace()
                    speech = speech[:, :, torch.randperm(8)[0]]
        # pdb.set_trace()
        batch_size = speech.shape[0]
        # for data-parallel
        text = text[:, : text_lengths.max()]
@@ -188,20 +192,19 @@
        return loss, stats, weight
    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
            self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
@@ -220,14 +223,14 @@
        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)
        #pdb.set_trace()
        # pdb.set_trace()
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        if(encoder_out.dim()==4):
        if (encoder_out.dim() == 4):
            assert encoder_out.size(2) <= encoder_out_lens.max(), (
                encoder_out.size(),
                encoder_out_lens.max(),
@@ -241,7 +244,7 @@
        return encoder_out, encoder_out_lens
    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
            self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # for data-parallel
@@ -259,11 +262,11 @@
        return feats, feats_lengths, channel_size
    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
            self,
            encoder_out: torch.Tensor,
            encoder_out_lens: torch.Tensor,
            ys_pad: torch.Tensor,
            ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1
@@ -291,14 +294,14 @@
        return loss_att, acc_att, cer_att, wer_att
    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
            self,
            encoder_out: torch.Tensor,
            encoder_out_lens: torch.Tensor,
            ys_pad: torch.Tensor,
            ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        if(encoder_out.dim()==4):
        if (encoder_out.dim() == 4):
            encoder_out = encoder_out.mean(1)
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
@@ -310,10 +313,10 @@
        return loss_ctc, cer_ctc
    def _calc_rnnt_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
            self,
            encoder_out: torch.Tensor,
            encoder_out_lens: torch.Tensor,
            ys_pad: torch.Tensor,
            ys_pad_lens: torch.Tensor,
    ):
        raise NotImplementedError
        raise NotImplementedError