From dfa356a10c698e4e0548ab2d05ae31ab142bd4aa Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 11 Apr 2023 00:27:54 +0800
Subject: [PATCH] Add chunk-wise streaming inference, restore the CTC/attention loss combination, and make hotword embedding configurable
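
This patch makes the following changes to funasr/models/e2e_asr_paraformer.py:

* Replace the Abs* constructor annotations (AbsFrontend, AbsSpecAug,
  AbsNormalize, AbsEncoder) with torch.nn.Module, and rebase Paraformer on
  FunASRModel instead of AbsESPnetModel.
* Add chunk-wise streaming entry points -- encode_chunk, calc_predictor_chunk,
  and cal_decoder_with_predictor_chunk -- that thread a caller-managed cache
  dict through the encoder, CIF predictor, and decoder.
* Restore the CTC / intermediate-CTC / attention branches in the forward pass
  that previously trained on loss_pre2 alone, and let ContextualParaformer
  draw hotword embeddings from the decoder's embedding table via the new
  use_decoder_embedding flag.

For reviewers, a minimal sketch of how the new chunk methods are meant to
chain during streaming inference. The cache layout ({"encoder": ...,
"decoder": ...}) follows the keys this patch reads; the model handle, the
chunking of the input, and the assumption that intermediate CTC is disabled
are all hypothetical:

    import torch

    def stream_decode(model, chunks):
        """Greedy chunk-by-chunk decoding with a shared streaming cache."""
        cache = {"encoder": {}, "decoder": {}}  # assumed layout; see the lookups in the patch
        hyps = []
        for speech in chunks:  # each chunk: (1, time, feat_dim) float tensor
            lengths = torch.tensor([speech.size(1)])
            # Frontend + encoder over this chunk (assumes no intermediate CTC outputs)
            enc, enc_lens = model.encode_chunk(speech, lengths, cache=cache)
            # CIF predictor emits per-chunk acoustic embeddings for the decoder
            pre_embeds, pre_len, alphas, peaks = model.calc_predictor_chunk(enc, cache=cache)
            # Decoder returns log-probabilities over the vocabulary
            log_probs = model.cal_decoder_with_predictor_chunk(enc, pre_embeds, cache=cache)
            hyps.append(log_probs.argmax(dim=-1))  # greedy token ids for this chunk
        return hyps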

---
 funasr/models/e2e_asr_paraformer.py |  189 ++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 160 insertions(+), 29 deletions(-)

diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 5786bc4..288f469 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -12,24 +12,20 @@
 import numpy as np
 from typeguard import check_argument_types
 
-from funasr.layers.abs_normalize import AbsNormalize
 from funasr.losses.label_smoothing_loss import (
     LabelSmoothingLoss,  # noqa: H301
 )
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.e2e_asr_common import ErrorCalculator
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
 from funasr.models.predictor.cif import mae_loss
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.base_model import FunASRModel
 from funasr.modules.add_sos_eos import add_sos_eos
 from funasr.modules.nets_utils import make_pad_mask, pad_list
 from funasr.modules.nets_utils import th_accuracy
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
 from funasr.models.predictor.cif import CifPredictorV3
 
 
@@ -42,7 +38,7 @@
         yield
 
 
-class Paraformer(AbsESPnetModel):
+class Paraformer(FunASRModel):
     """
     Author: Speech Lab, Alibaba Group, China
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
@@ -53,11 +49,11 @@
             self,
             vocab_size: int,
             token_list: Union[Tuple[str, ...], List[str]],
-            frontend: Optional[AbsFrontend],
-            specaug: Optional[AbsSpecAug],
-            normalize: Optional[AbsNormalize],
+            frontend: Optional[torch.nn.Module],
+            specaug: Optional[torch.nn.Module],
+            normalize: Optional[torch.nn.Module],
             preencoder: Optional[AbsPreEncoder],
-            encoder: AbsEncoder,
+            encoder: torch.nn.Module,
             postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
@@ -325,12 +321,67 @@
 
         return encoder_out, encoder_out_lens
 
+    def encode_chunk(
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # Pre-encoder, e.g. used for raw input data
+        if self.preencoder is not None:
+            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+        # 4. Forward encoder
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        if self.encoder.interctc_use_conditioning:
+            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
+                feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
+            )
+        else:
+            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+
+        # Post-encoder, e.g. NLU
+        if self.postencoder is not None:
+            encoder_out, encoder_out_lens = self.postencoder(
+                encoder_out, encoder_out_lens
+            )
+
+        if intermediate_outs is not None:
+            return (encoder_out, intermediate_outs), encoder_out_lens
+
+        return encoder_out, torch.tensor([encoder_out.size(1)])
+
     def calc_predictor(self, encoder_out, encoder_out_lens):
 
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
         pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
                                                                                   ignore_id=self.ignore_id)
+        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+    def calc_predictor_chunk(self, encoder_out, cache=None):
+
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"])
         return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
 
     def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
@@ -341,6 +392,14 @@
         decoder_out = decoder_outs[0]
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
         return decoder_out, ys_pad_lens
+
+    def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+        decoder_outs = self.decoder.forward_chunk(
+            encoder_out, sematic_embeds, cache["decoder"]
+        )
+        decoder_out = decoder_outs
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        return decoder_out
 
     def _extract_feats(
             self, speech: torch.Tensor, speech_lengths: torch.Tensor
@@ -557,11 +616,11 @@
             self,
             vocab_size: int,
             token_list: Union[Tuple[str, ...], List[str]],
-            frontend: Optional[AbsFrontend],
-            specaug: Optional[AbsSpecAug],
-            normalize: Optional[AbsNormalize],
+            frontend: Optional[torch.nn.Module],
+            specaug: Optional[torch.nn.Module],
+            normalize: Optional[torch.nn.Module],
             preencoder: Optional[AbsPreEncoder],
-            encoder: AbsEncoder,
+            encoder: torch.nn.Module,
             postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
@@ -835,11 +894,11 @@
         self,
         vocab_size: int,
         token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
+        frontend: Optional[torch.nn.Module],
+        specaug: Optional[torch.nn.Module],
+        normalize: Optional[torch.nn.Module],
         preencoder: Optional[AbsPreEncoder],
-        encoder: AbsEncoder,
+        encoder: torch.nn.Module,
         postencoder: Optional[AbsPostEncoder],
         decoder: AbsDecoder,
         ctc: CTC,
@@ -926,10 +985,10 @@
     def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out,
+        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
                                                                                                encoder_out_mask,
                                                                                                token_num)
-        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
+        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
 
     def forward(
             self,
@@ -962,21 +1021,82 @@
 
         # 1. Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
 
+        loss_att, acc_att, cer_att, wer_att = None, None, None, None
+        loss_ctc, cer_ctc = None, None
+        loss_pre = None
         stats = dict()
+
+        # 1. CTC branch
+        if self.ctc_weight != 0.0:
+            loss_ctc, cer_ctc = self._calc_ctc_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+
+            # Collect CTC branch stats
+            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+            stats["cer_ctc"] = cer_ctc
+
+        # Intermediate CTC (optional)
+        loss_interctc = 0.0
+        if self.interctc_weight != 0.0 and intermediate_outs is not None:
+            for layer_idx, intermediate_out in intermediate_outs:
+                # we assume intermediate_out has the same length & padding
+                # as those of encoder_out
+                loss_ic, cer_ic = self._calc_ctc_loss(
+                    intermediate_out, encoder_out_lens, text, text_lengths
+                )
+                loss_interctc = loss_interctc + loss_ic
+
+                # Collect intermediate CTC stats
+                stats["loss_interctc_layer{}".format(layer_idx)] = (
+                    loss_ic.detach() if loss_ic is not None else None
+                )
+                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+            loss_interctc = loss_interctc / len(intermediate_outs)
+
+            # calculate whole encoder loss
+            loss_ctc = (
+                1 - self.interctc_weight
+            ) * loss_ctc + self.interctc_weight * loss_interctc
+
+        # 2b. Attention decoder branch
+        if self.ctc_weight != 1.0:
+            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
 
         loss_pre2 = self._calc_pre2_loss(
             encoder_out, encoder_out_lens, text, text_lengths
         )
 
-        loss = loss_pre2
+        # 3. CTC-Att loss definition
+        if self.ctc_weight == 0.0:
+            loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+        elif self.ctc_weight == 1.0:
+            loss = loss_ctc
+        else:
+            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
 
+        # Collect Attn branch stats
+        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+        stats["acc"] = acc_att
+        stats["cer"] = cer_att
+        stats["wer"] = wer_att
+        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
         stats["loss_pre2"] = loss_pre2.detach().cpu()
+
         stats["loss"] = torch.clone(loss.detach())
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
+
 
 class ContextualParaformer(Paraformer):
     """
@@ -987,11 +1107,11 @@
             self,
             vocab_size: int,
             token_list: Union[Tuple[str, ...], List[str]],
-            frontend: Optional[AbsFrontend],
-            specaug: Optional[AbsSpecAug],
-            normalize: Optional[AbsNormalize],
+            frontend: Optional[torch.nn.Module],
+            specaug: Optional[torch.nn.Module],
+            normalize: Optional[torch.nn.Module],
             preencoder: Optional[AbsPreEncoder],
-            encoder: AbsEncoder,
+            encoder: torch.nn.Module,
             postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
@@ -1021,6 +1141,7 @@
             inner_dim: int = 256,
             bias_encoder_type: str = 'lstm',
             label_bracket: bool = False,
+            use_decoder_embedding: bool = False,
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -1074,6 +1195,7 @@
             self.hotword_buffer = None
             self.length_record = []
             self.current_buffer_length = 0
+        self.use_decoder_embedding = use_decoder_embedding
 
     def forward(
             self,
@@ -1215,7 +1337,10 @@
                     hw_list.append(hw_tokens)
         # padding
         hw_list_pad = pad_list(hw_list, 0)
-        hw_embed = self.decoder.embed(hw_list_pad)
+        if self.use_decoder_embedding:
+            hw_embed = self.decoder.embed(hw_list_pad)
+        else:
+            hw_embed = self.bias_embed(hw_list_pad)
         hw_embed, (_, _) = self.bias_encoder(hw_embed)
         _ind = np.arange(0, len(hw_list)).tolist()
         # update self.hotword_buffer, throw a part if oversize
@@ -1331,13 +1456,19 @@
             # default hotword list
             hw_list = [torch.Tensor([self.sos]).long().to(encoder_out.device)]  # empty hotword list
             hw_list_pad = pad_list(hw_list, 0)
-            hw_embed = self.bias_embed(hw_list_pad)
+            if self.use_decoder_embedding:
+                hw_embed = self.decoder.embed(hw_list_pad)
+            else:
+                hw_embed = self.bias_embed(hw_list_pad)
             _, (h_n, _) = self.bias_encoder(hw_embed)
             contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)
         else:
             hw_lengths = [len(i) for i in hw_list]
             hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
-            hw_embed = self.bias_embed(hw_list_pad)
+            if self.use_decoder_embedding:
+                hw_embed = self.decoder.embed(hw_list_pad)
+            else:
+                hw_embed = self.bias_embed(hw_list_pad)
             hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
                                                                enforce_sorted=False)
             _, (h_n, _) = self.bias_encoder(hw_embed)
@@ -1458,4 +1589,4 @@
                     "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                   var_dict_tf[name_tf].shape))
 
-        return var_dict_torch_update
\ No newline at end of file
+        return var_dict_torch_update
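
Reviewer note on the restored loss definition: with 0 < ctc_weight < 1, the
objective computed above is

    loss = ctc_weight * loss_ctc
           + (1 - ctc_weight) * loss_att
           + predictor_weight * loss_pre
           + 0.5 * predictor_weight * loss_pre2

where, when intermediate CTC is enabled, loss_ctc has already been blended as
(1 - interctc_weight) * loss_ctc + interctc_weight * (mean of the per-layer
intermediate CTC losses). The new use_decoder_embedding flag defaults to
False, so all three hotword call sites now use bias_embed; before this patch
one site used decoder.embed while the other two used bias_embed. Setting the
flag to True makes hotwords share the decoder's embedding table instead.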

--
Gitblit v1.9.1