From acb9a0fec8d8a4dabeedcbb8e08c26f66d7083f0 Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: Fri, 08 Dec 2023 16:19:00 +0800
Subject: [PATCH] Fix loss normalization for DDP training
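
The weight returned from forward() via force_gatherable() was always the
utterance count, even when length_normalized_loss=True and the loss itself is
already normalized by the number of target tokens. Assuming the trainer
combines per-rank losses as a weighted average over this value (the usual
DDP/DataParallel behaviour this code targets), token-normalized losses from
ranks with longer utterances were under-weighted. With this patch, whenever
length_normalized_loss is set, the weight becomes the total number of target
tokens, including the token(s) added by predictor_bias. The optional 1st-pass
decoder loss is also scaled by (1 - ctc_weight) so it matches the weighting of
the attention loss.

A minimal sketch of why the weight matters (the helper name is hypothetical
and not part of this patch; it only illustrates the weighted average):

    import torch

    def ddp_weighted_average(losses, weights):
        # weighted mean used to combine per-rank, already-normalized losses
        losses = torch.tensor(losses, dtype=torch.float64)
        weights = torch.tensor(weights, dtype=torch.float64)
        return float((losses * weights).sum() / weights.sum())

    # rank 0: 100 target tokens, mean per-token loss 2.0
    # rank 1: 300 target tokens, mean per-token loss 1.0
    # global token-normalized loss = (2.0*100 + 1.0*300) / 400 = 1.25
    print(ddp_weighted_average([2.0, 1.0], [100, 300]))  # 1.25, token-count weights
    print(ddp_weighted_average([2.0, 1.0], [8, 8]))      # 1.50, utterance-count weights (biased)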

---
 funasr/models/e2e_asr_paraformer.py |  151 ++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 126 insertions(+), 25 deletions(-)

diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 54db971..0e0b95b 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -10,7 +10,6 @@
 import torch
 import random
 import numpy as np
-from typeguard import check_argument_types
 
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.losses.label_smoothing_loss import (
@@ -80,7 +79,6 @@
             postencoder: Optional[AbsPostEncoder] = None,
             use_1st_decoder_loss: bool = False,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
@@ -139,6 +137,7 @@
         self.predictor_bias = predictor_bias
         self.sampling_ratio = sampling_ratio
         self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
+        self.length_normalized_loss = length_normalized_loss
         self.step_cur = 0
 
         self.share_embedding = share_embedding
@@ -242,7 +241,7 @@
             loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
 
         if self.use_1st_decoder_loss and pre_loss_att is not None:
-            loss = loss + pre_loss_att
+            loss = loss + (1 - self.ctc_weight) * pre_loss_att
 
         # Collect Attn branch stats
         stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -255,6 +254,8 @@
         stats["loss"] = torch.clone(loss.detach())
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -279,7 +280,7 @@
 
     def encode(
             self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
-    ) -> Tuple[Tuple[Any, Optional[Any]], Any]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
                 speech: (Batch, Length, ...)
@@ -354,8 +355,9 @@
 
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
-                                                                                  ignore_id=self.ignore_id)
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
+                                                                                       encoder_out_mask,
+                                                                                       ignore_id=self.ignore_id)
         return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
 
     def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
@@ -489,8 +491,9 @@
             if self.step_cur < 2:
                 logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
             if self.use_1st_decoder_loss:
-                sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
-                                                               pre_acoustic_embeds)
+                sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(encoder_out, encoder_out_lens,
+                                                                                       ys_pad, ys_pad_lens,
+                                                                                       pre_acoustic_embeds)
             else:
                 sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                                                                pre_acoustic_embeds)
@@ -645,11 +648,38 @@
             postencoder: Optional[AbsPostEncoder] = None,
             use_1st_decoder_loss: bool = False,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
-        super().__init__()
+        super().__init__(
+            vocab_size=vocab_size,
+            token_list=token_list,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            preencoder=preencoder,
+            encoder=encoder,
+            postencoder=postencoder,
+            decoder=decoder,
+            ctc=ctc,
+            ctc_weight=ctc_weight,
+            interctc_weight=interctc_weight,
+            ignore_id=ignore_id,
+            blank_id=blank_id,
+            sos=sos,
+            eos=eos,
+            lsm_weight=lsm_weight,
+            length_normalized_loss=length_normalized_loss,
+            report_cer=report_cer,
+            report_wer=report_wer,
+            sym_space=sym_space,
+            sym_blank=sym_blank,
+            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+            predictor=predictor,
+            predictor_weight=predictor_weight,
+            predictor_bias=predictor_bias,
+            sampling_ratio=sampling_ratio,
+        )
         # note that eos is the same as sos (equivalent ID)
         self.blank_id = blank_id
         self.sos = vocab_size - 1 if sos is None else sos
@@ -702,9 +732,11 @@
         self.predictor = predictor
         self.predictor_weight = predictor_weight
         self.predictor_bias = predictor_bias
+        self.length_normalized_loss = length_normalized_loss
         self.sampling_ratio = sampling_ratio
         self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
         self.step_cur = 0
+        self.scama_mask = None
         if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
             from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
             self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
@@ -834,11 +866,13 @@
         stats["loss"] = torch.clone(loss.detach())
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
     def encode(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
@@ -944,11 +978,11 @@
         return encoder_out, torch.tensor([encoder_out.size(1)])
 
     def _calc_att_predictor_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
     ):
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
@@ -980,7 +1014,7 @@
             attention_chunk_center_bias = 0
             attention_chunk_size = encoder_chunk_size
             decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
-            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\
+            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
                 get_mask_shift_att_chunk_decoder(None,
                                                  device=encoder_out.device,
                                                  batch_size=encoder_out.size(0)
@@ -1080,7 +1114,8 @@
             input_mask_expand_dim, 0)
         return sematic_embeds * tgt_mask, decoder_out * tgt_mask
 
-    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
+    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds,
+                          chunk_mask=None):
         tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
         ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
         if self.share_embedding:
@@ -1111,11 +1146,72 @@
 
         return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
 
+    def calc_predictor(self, encoder_out, encoder_out_lens):
+
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        mask_chunk_predictor = None
+        if self.encoder.overlap_chunk_cls is not None:
+            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+                                                                                           device=encoder_out.device,
+                                                                                           batch_size=encoder_out.size(
+                                                                                               0))
+            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+                                                                                   batch_size=encoder_out.size(0))
+            encoder_out = encoder_out * mask_shfit_chunk
+        pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
+                                                                                           None,
+                                                                                           encoder_out_mask,
+                                                                                           ignore_id=self.ignore_id,
+                                                                                           mask_chunk_predictor=mask_chunk_predictor,
+                                                                                           target_label_length=None,
+                                                                                           )
+        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+                                                                                             encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
+
+        scama_mask = None
+        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+            attention_chunk_center_bias = 0
+            attention_chunk_size = encoder_chunk_size
+            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
+                get_mask_shift_att_chunk_decoder(None,
+                                                 device=encoder_out.device,
+                                                 batch_size=encoder_out.size(0)
+                                                 )
+            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+                predictor_alignments=predictor_alignments,
+                encoder_sequence_length=encoder_out_lens,
+                chunk_size=1,
+                encoder_chunk_size=encoder_chunk_size,
+                attention_chunk_center_bias=attention_chunk_center_bias,
+                attention_chunk_size=attention_chunk_size,
+                attention_chunk_type=self.decoder_attention_chunk_type,
+                step=None,
+                predictor_mask_chunk_hopping=mask_chunk_predictor,
+                decoder_att_look_back_factor=decoder_att_look_back_factor,
+                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+                target_length=None,
+                is_training=self.training,
+            )
+        self.scama_mask = scama_mask
+
+        return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
+
     def calc_predictor_chunk(self, encoder_out, cache=None):
 
         pre_acoustic_embeds, pre_token_length = \
             self.predictor.forward_chunk(encoder_out, cache["encoder"])
         return pre_acoustic_embeds, pre_token_length
+
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
+        )
+        decoder_out = decoder_outs[0]
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        return decoder_out, ys_pad_lens
 
     def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
         decoder_outs = self.decoder.forward_chunk(
@@ -1165,7 +1261,6 @@
             preencoder: Optional[AbsPreEncoder] = None,
             postencoder: Optional[AbsPostEncoder] = None,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
@@ -1398,6 +1493,8 @@
         stats["loss"] = torch.clone(loss.detach())
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -1438,7 +1535,6 @@
             preencoder: Optional[AbsPreEncoder] = None,
             postencoder: Optional[AbsPostEncoder] = None,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
@@ -1504,8 +1600,9 @@
         if self.predictor_bias == 1:
             _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
             ys_pad_lens = ys_pad_lens + self.predictor_bias
-        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
-                                                                                  ignore_id=self.ignore_id)
+        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
+                                                                                     encoder_out_mask,
+                                                                                     ignore_id=self.ignore_id)
 
         # 0. sampler
         decoder_out_1st = None
@@ -1654,7 +1751,7 @@
             loss = loss_ctc
         else:
             loss = self.ctc_weight * loss_ctc + (
-                        1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+                    1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
 
         # Collect Attn branch stats
         stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -1667,6 +1764,8 @@
         stats["loss"] = torch.clone(loss.detach())
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -1716,7 +1815,6 @@
             preencoder: Optional[AbsPreEncoder] = None,
             postencoder: Optional[AbsPostEncoder] = None,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
@@ -1868,6 +1966,8 @@
         stats["loss"] = torch.clone(loss.detach())
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + self.predictor_bias).sum())
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -2023,7 +2123,8 @@
 
         return loss_att, acc_att, cer_att, wer_att, loss_pre
 
-    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None,
+                                   clas_scale=1.0):
         if hw_list is None:
             # default hotword list
             hw_list = [torch.Tensor([self.sos]).long().to(encoder_out.device)]  # empty hotword list

--
Gitblit v1.9.1