From 6f7e27eb7c2d0a7649ec8f14d167c8da8e29f906 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 16 五月 2023 15:07:20 +0800
Subject: [PATCH] Merge pull request #518 from alibaba-damo-academy/dev_wjm2

---
 funasr/models/e2e_asr_mfcca.py |  138 ++++++++++++++++++++++-----------------------
 1 files changed, 68 insertions(+), 70 deletions(-)

diff --git a/funasr/models/e2e_asr_mfcca.py b/funasr/models/e2e_asr_mfcca.py
index f22f12a..fbf0d11 100644
--- a/funasr/models/e2e_asr_mfcca.py
+++ b/funasr/models/e2e_asr_mfcca.py
@@ -23,7 +23,7 @@
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -35,7 +35,8 @@
 import pdb
 import random
 import math
-class MFCCA(AbsESPnetModel):
+
+class MFCCA(FunASRModel):
     """
     Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
     MFCCA:Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario
@@ -43,26 +44,26 @@
     """
 
     def __init__(
-        self,
-        vocab_size: int,
-        token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
-        preencoder: Optional[AbsPreEncoder],
-        encoder: AbsEncoder,
-        decoder: AbsDecoder,
-        ctc: CTC,
-        rnnt_decoder: None,
-        ctc_weight: float = 0.5,
-        ignore_id: int = -1,
-        lsm_weight: float = 0.0,
-        mask_ratio: float = 0.0,
-        length_normalized_loss: bool = False,
-        report_cer: bool = True,
-        report_wer: bool = True,
-        sym_space: str = "<space>",
-        sym_blank: str = "<blank>",
+            self,
+            vocab_size: int,
+            token_list: Union[Tuple[str, ...], List[str]],
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
+            encoder: AbsEncoder,
+            decoder: AbsDecoder,
+            ctc: CTC,
+            rnnt_decoder: None,
+            ctc_weight: float = 0.5,
+            ignore_id: int = -1,
+            lsm_weight: float = 0.0,
+            mask_ratio: float = 0.0,
+            length_normalized_loss: bool = False,
+            report_cer: bool = True,
+            report_wer: bool = True,
+            sym_space: str = "<space>",
+            sym_blank: str = "<blank>",
+            preencoder: Optional[AbsPreEncoder] = None,
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -76,10 +77,9 @@
         self.ignore_id = ignore_id
         self.ctc_weight = ctc_weight
         self.token_list = token_list.copy()
-        
+
         self.mask_ratio = mask_ratio
 
-        
         self.frontend = frontend
         self.specaug = specaug
         self.normalize = normalize
@@ -113,14 +113,13 @@
             self.error_calculator = None
 
     def forward(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        text: torch.Tensor,
-        text_lengths: torch.Tensor,
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -130,22 +129,22 @@
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
         assert (
-            speech.shape[0]
-            == speech_lengths.shape[0]
-            == text.shape[0]
-            == text_lengths.shape[0]
+                speech.shape[0]
+                == speech_lengths.shape[0]
+                == text.shape[0]
+                == text_lengths.shape[0]
         ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
-        #pdb.set_trace()
-        if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
+        # pdb.set_trace()
+        if (speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0):
             rate_num = random.random()
-            #rate_num = 0.1
-            if(rate_num<=self.mask_ratio):
-                retain_channel = math.ceil(random.random() *8)
-                if(retain_channel>1):
-                    speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
+            # rate_num = 0.1
+            if (rate_num <= self.mask_ratio):
+                retain_channel = math.ceil(random.random() * 8)
+                if (retain_channel > 1):
+                    speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
                 else:
-                    speech = speech[:,:,torch.randperm(8)[0]]
-        #pdb.set_trace()
+                    speech = speech[:, :, torch.randperm(8)[0]]
+        # pdb.set_trace()
         batch_size = speech.shape[0]
         # for data-parallel
         text = text[:, : text_lengths.max()]
@@ -195,20 +194,19 @@
         return loss, stats, weight
 
     def collect_feats(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        text: torch.Tensor,
-        text_lengths: torch.Tensor,
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
     ) -> Dict[str, torch.Tensor]:
         feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
         return {"feats": feats, "feats_lengths": feats_lengths}
 
     def encode(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -227,14 +225,14 @@
         # Pre-encoder, e.g. used for raw input data
         if self.preencoder is not None:
             feats, feats_lengths = self.preencoder(feats, feats_lengths)
-        #pdb.set_trace()
+        # pdb.set_trace()
         encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
 
         assert encoder_out.size(0) == speech.size(0), (
             encoder_out.size(),
             speech.size(0),
         )
-        if(encoder_out.dim()==4):
+        if (encoder_out.dim() == 4):
             assert encoder_out.size(2) <= encoder_out_lens.max(), (
                 encoder_out.size(),
                 encoder_out_lens.max(),
@@ -248,7 +246,7 @@
         return encoder_out, encoder_out_lens
 
     def _extract_feats(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert speech_lengths.dim() == 1, speech_lengths.shape
         # for data-parallel
@@ -266,11 +264,11 @@
         return feats, feats_lengths, channel_size
 
     def _calc_att_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
     ):
         ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
         ys_in_lens = ys_pad_lens + 1
@@ -298,14 +296,14 @@
         return loss_att, acc_att, cer_att, wer_att
 
     def _calc_ctc_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
     ):
         # Calc CTC loss
-        if(encoder_out.dim()==4):
+        if (encoder_out.dim() == 4):
             encoder_out = encoder_out.mean(1)
         loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
 
@@ -317,10 +315,10 @@
         return loss_ctc, cer_ctc
 
     def _calc_rnnt_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
     ):
-        raise NotImplementedError
+        raise NotImplementedError
\ No newline at end of file

--
Gitblit v1.9.1