From e6fe602db3eb1209543e55f1aafa2932dfda3310 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 10 一月 2025 10:14:30 +0800
Subject: [PATCH] step_or_epoch bugfix

---
 funasr/models/uniasr/model.py |  433 ++++++++++++++++++++++++++++++++---------------------
 1 files changed, 258 insertions(+), 175 deletions(-)

diff --git a/funasr/models/uniasr/model.py b/funasr/models/uniasr/model.py
index 6e564dc..bde6377 100644
--- a/funasr/models/uniasr/model.py
+++ b/funasr/models/uniasr/model.py
@@ -22,6 +22,7 @@
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 from funasr.models.scama.utils import sequence_mask
 
+
 @tables.register("model_classes", "UniASR")
 class UniASR(torch.nn.Module):
     """
@@ -56,8 +57,8 @@
         ctc2: str = None,
         ctc2_conf: dict = None,
         ctc2_weight: float = 0.5,
-        decoder_attention_chunk_type: str = 'chunk',
-        decoder_attention_chunk_type2: str = 'chunk',
+        decoder_attention_chunk_type: str = "chunk",
+        decoder_attention_chunk_type2: str = "chunk",
         stride_conv=None,
         stride_conv_conf: dict = None,
         loss_weight_model1: float = 0.5,
@@ -71,7 +72,6 @@
         length_normalized_loss: bool = False,
         share_embedding: bool = False,
         **kwargs,
-        
     ):
         super().__init__()
 
@@ -81,7 +81,7 @@
         if normalize is not None:
             normalize_class = tables.normalize_classes.get(normalize)
             normalize = normalize_class(**normalize_conf)
-            
+
         encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()
@@ -94,12 +94,14 @@
         )
         predictor_class = tables.predictor_classes.get(predictor)
         predictor = predictor_class(**predictor_conf)
-        
 
-        
         from funasr.models.transformer.utils.subsampling import Conv1dSubsampling
-        stride_conv = Conv1dSubsampling(**stride_conv_conf, idim=input_size + encoder_output_size,
-                                        odim=input_size + encoder_output_size)
+
+        stride_conv = Conv1dSubsampling(
+            **stride_conv_conf,
+            idim=input_size + encoder_output_size,
+            odim=input_size + encoder_output_size,
+        )
         stride_conv_output_size = stride_conv.output_size()
 
         encoder_class = tables.encoder_classes.get(encoder2)
@@ -115,8 +117,6 @@
         predictor_class = tables.predictor_classes.get(predictor2)
         predictor2 = predictor_class(**predictor2_conf)
 
-
-        
         self.blank_id = blank_id
         self.sos = sos
         self.eos = eos
@@ -127,7 +127,7 @@
 
         self.specaug = specaug
         self.normalize = normalize
-        
+
         self.encoder = encoder
 
         self.error_calculator = None
@@ -142,16 +142,20 @@
             smoothing=lsm_weight,
             normalize_length=length_normalized_loss,
         )
-        
+
         self.predictor = predictor
         self.predictor_weight = predictor_weight
         self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
         self.encoder1_encoder2_joint_training = kwargs.get("encoder1_encoder2_joint_training", True)
-        
 
         if self.encoder.overlap_chunk_cls is not None:
-            from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
-            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
+            from funasr.models.scama.chunk_utilis import (
+                build_scama_mask_for_cross_attention_decoder,
+            )
+
+            self.build_scama_mask_for_cross_attention_decoder_fn = (
+                build_scama_mask_for_cross_attention_decoder
+            )
             self.decoder_attention_chunk_type = decoder_attention_chunk_type
 
         self.encoder2 = encoder2
@@ -164,8 +168,13 @@
         self.stride_conv = stride_conv
         self.loss_weight_model1 = loss_weight_model1
         if self.encoder2.overlap_chunk_cls is not None:
-            from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
-            self.build_scama_mask_for_cross_attention_decoder_fn2 = build_scama_mask_for_cross_attention_decoder
+            from funasr.models.scama.chunk_utilis import (
+                build_scama_mask_for_cross_attention_decoder,
+            )
+
+            self.build_scama_mask_for_cross_attention_decoder_fn2 = (
+                build_scama_mask_for_cross_attention_decoder
+            )
             self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
 
         self.length_normalized_loss = length_normalized_loss
@@ -196,15 +205,15 @@
 
         batch_size = speech.shape[0]
 
-
         ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
         # 1. Encoder
         if self.enable_maas_finetune:
             with torch.no_grad():
-                speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+                speech_raw, encoder_out, encoder_out_lens = self.encode(
+                    speech, speech_lengths, ind=ind
+                )
         else:
             speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
-
 
         loss_att, acc_att, cer_att, wer_att = None, None, None, None
         loss_ctc, cer_ctc = None, None
@@ -231,11 +240,10 @@
                     stats["wer"] = wer_att
                     stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
             else:
-                
+
                 loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
                     encoder_out, encoder_out_lens, text, text_lengths
                 )
-
 
                 loss = loss_att + loss_pre * self.predictor_weight
 
@@ -254,20 +262,22 @@
             # encoder2
             if self.freeze_encoder2:
                 with torch.no_grad():
-                    encoder_out, encoder_out_lens = self.encode2(encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=ind)
+                    encoder_out, encoder_out_lens = self.encode2(
+                        encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=ind
+                    )
             else:
-                encoder_out, encoder_out_lens = self.encode2(encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=ind)
+                encoder_out, encoder_out_lens = self.encode2(
+                    encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=ind
+                )
 
             intermediate_outs = None
             if isinstance(encoder_out, tuple):
                 intermediate_outs = encoder_out[1]
                 encoder_out = encoder_out[0]
 
-
             loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2(
                 encoder_out, encoder_out_lens, text, text_lengths
             )
-
 
             loss = loss_att + loss_pre * self.predictor2_weight
 
@@ -277,7 +287,7 @@
             stats["cer2"] = cer_att
             stats["wer2"] = wer_att
             stats["loss_pre2"] = loss_pre.detach().cpu() if loss_pre is not None else None
-        
+
         loss2 = loss
 
         loss = loss1 * self.loss_weight_model1 + loss2 * (1 - self.loss_weight_model1)
@@ -287,7 +297,6 @@
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
             batch_size = int((text_lengths + 1).sum())
-
 
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
@@ -312,7 +321,10 @@
         return {"feats": feats, "feats_lengths": feats_lengths}
 
     def encode(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        **kwargs,
     ):
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
@@ -324,13 +336,12 @@
             # Data augmentation
             if self.specaug is not None and self.training:
                 speech, speech_lengths = self.specaug(speech, speech_lengths)
-    
+
             # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
             if self.normalize is not None:
                 speech, speech_lengths = self.normalize(speech, speech_lengths)
-                
-        speech_raw = speech.clone().to(speech.device)
 
+        speech_raw = speech.clone().to(speech.device)
 
         # 4. Forward encoder
         encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ind=ind)
@@ -375,9 +386,7 @@
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
 
-
         return encoder_out, encoder_out_lens
-
 
     def nll(
         self,
@@ -472,9 +481,7 @@
         ys_in_lens = ys_pad_lens + 1
 
         # 1. Forward decoder
-        decoder_out, _ = self.decoder(
-            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
-        )
+        decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)
 
         # 2. Compute attention loss
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
@@ -503,37 +510,49 @@
         ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
         ys_in_lens = ys_pad_lens + 1
 
-        encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
-                                         device=encoder_out.device)[:, None, :]
+        encoder_out_mask = sequence_mask(
+            encoder_out_lens,
+            maxlen=encoder_out.size(1),
+            dtype=encoder_out.dtype,
+            device=encoder_out.device,
+        )[:, None, :]
         mask_chunk_predictor = None
         if self.encoder.overlap_chunk_cls is not None:
-            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
-                                                                                           device=encoder_out.device,
-                                                                                           batch_size=encoder_out.size(
-                                                                                               0))
-            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
-                                                                                   batch_size=encoder_out.size(0))
+            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0)
+            )
+            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0)
+            )
             encoder_out = encoder_out * mask_shfit_chunk
-        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
-                                                                              ys_out_pad,
-                                                                              encoder_out_mask,
-                                                                              ignore_id=self.ignore_id,
-                                                                              mask_chunk_predictor=mask_chunk_predictor,
-                                                                              target_label_length=ys_in_lens,
-                                                                              )
-        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
-                                                                                             encoder_out_lens)
+        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(
+            encoder_out,
+            ys_out_pad,
+            encoder_out_mask,
+            ignore_id=self.ignore_id,
+            mask_chunk_predictor=mask_chunk_predictor,
+            target_label_length=ys_in_lens,
+        )
+        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
+            pre_alphas, encoder_out_lens
+        )
 
         scama_mask = None
-        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+        if (
+            self.encoder.overlap_chunk_cls is not None
+            and self.decoder_attention_chunk_type == "chunk"
+        ):
             encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
             attention_chunk_center_bias = 0
             attention_chunk_size = encoder_chunk_size
-            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
-            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
-                                                                                                           device=encoder_out.device,
-                                                                                                           batch_size=encoder_out.size(
-                                                                                                               0))
+            decoder_att_look_back_factor = (
+                self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+            )
+            mask_shift_att_chunk_decoder = (
+                self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
+                    None, device=encoder_out.device, batch_size=encoder_out.size(0)
+                )
+            )
             scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                 predictor_alignments=predictor_alignments,
                 encoder_sequence_length=encoder_out_lens,
@@ -550,8 +569,9 @@
                 is_training=self.training,
             )
         elif self.encoder.overlap_chunk_cls is not None:
-            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
-                                                                                        chunk_outs=None)
+            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
+                encoder_out, encoder_out_lens, chunk_outs=None
+            )
         # try:
         # 1. Forward decoder
         decoder_out, _ = self.decoder(
@@ -561,7 +581,6 @@
             ys_in_lens,
             chunk_mask=scama_mask,
             pre_acoustic_embeds=pre_acoustic_embeds,
-
         )
 
         # 2. Compute attention loss
@@ -592,37 +611,49 @@
         ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
         ys_in_lens = ys_pad_lens + 1
 
-        encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
-                                         device=encoder_out.device)[:, None, :]
+        encoder_out_mask = sequence_mask(
+            encoder_out_lens,
+            maxlen=encoder_out.size(1),
+            dtype=encoder_out.dtype,
+            device=encoder_out.device,
+        )[:, None, :]
         mask_chunk_predictor = None
         if self.encoder2.overlap_chunk_cls is not None:
-            mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(None,
-                                                                                            device=encoder_out.device,
-                                                                                            batch_size=encoder_out.size(
-                                                                                                0))
-            mask_shfit_chunk = self.encoder2.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
-                                                                                    batch_size=encoder_out.size(0))
+            mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0)
+            )
+            mask_shfit_chunk = self.encoder2.overlap_chunk_cls.get_mask_shfit_chunk(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0)
+            )
             encoder_out = encoder_out * mask_shfit_chunk
-        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor2(encoder_out,
-                                                                               ys_out_pad,
-                                                                               encoder_out_mask,
-                                                                               ignore_id=self.ignore_id,
-                                                                               mask_chunk_predictor=mask_chunk_predictor,
-                                                                               target_label_length=ys_in_lens,
-                                                                               )
-        predictor_alignments, predictor_alignments_len = self.predictor2.gen_frame_alignments(pre_alphas,
-                                                                                              encoder_out_lens)
+        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor2(
+            encoder_out,
+            ys_out_pad,
+            encoder_out_mask,
+            ignore_id=self.ignore_id,
+            mask_chunk_predictor=mask_chunk_predictor,
+            target_label_length=ys_in_lens,
+        )
+        predictor_alignments, predictor_alignments_len = self.predictor2.gen_frame_alignments(
+            pre_alphas, encoder_out_lens
+        )
 
         scama_mask = None
-        if self.encoder2.overlap_chunk_cls is not None and self.decoder_attention_chunk_type2 == 'chunk':
+        if (
+            self.encoder2.overlap_chunk_cls is not None
+            and self.decoder_attention_chunk_type2 == "chunk"
+        ):
             encoder_chunk_size = self.encoder2.overlap_chunk_cls.chunk_size_pad_shift_cur
             attention_chunk_center_bias = 0
             attention_chunk_size = encoder_chunk_size
-            decoder_att_look_back_factor = self.encoder2.overlap_chunk_cls.decoder_att_look_back_factor_cur
-            mask_shift_att_chunk_decoder = self.encoder2.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
-                                                                                                            device=encoder_out.device,
-                                                                                                            batch_size=encoder_out.size(
-                                                                                                                0))
+            decoder_att_look_back_factor = (
+                self.encoder2.overlap_chunk_cls.decoder_att_look_back_factor_cur
+            )
+            mask_shift_att_chunk_decoder = (
+                self.encoder2.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
+                    None, device=encoder_out.device, batch_size=encoder_out.size(0)
+                )
+            )
             scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn2(
                 predictor_alignments=predictor_alignments,
                 encoder_sequence_length=encoder_out_lens,
@@ -639,8 +670,9 @@
                 is_training=self.training,
             )
         elif self.encoder2.overlap_chunk_cls is not None:
-            encoder_out, encoder_out_lens = self.encoder2.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
-                                                                                         chunk_outs=None)
+            encoder_out, encoder_out_lens = self.encoder2.overlap_chunk_cls.remove_chunk(
+                encoder_out, encoder_out_lens, chunk_outs=None
+            )
         # try:
         # 1. Forward decoder
         decoder_out, _ = self.decoder2(
@@ -681,37 +713,49 @@
         # ys_in_lens = ys_pad_lens + 1
         ys_out_pad, ys_in_lens = None, None
 
-        encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
-                                         device=encoder_out.device)[:, None, :]
+        encoder_out_mask = sequence_mask(
+            encoder_out_lens,
+            maxlen=encoder_out.size(1),
+            dtype=encoder_out.dtype,
+            device=encoder_out.device,
+        )[:, None, :]
         mask_chunk_predictor = None
         if self.encoder.overlap_chunk_cls is not None:
-            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
-                                                                                           device=encoder_out.device,
-                                                                                           batch_size=encoder_out.size(
-                                                                                               0))
-            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
-                                                                                   batch_size=encoder_out.size(0))
+            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0)
+            )
+            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0)
+            )
             encoder_out = encoder_out * mask_shfit_chunk
-        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
-                                                                              ys_out_pad,
-                                                                              encoder_out_mask,
-                                                                              ignore_id=self.ignore_id,
-                                                                              mask_chunk_predictor=mask_chunk_predictor,
-                                                                              target_label_length=ys_in_lens,
-                                                                              )
-        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
-                                                                                             encoder_out_lens)
+        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(
+            encoder_out,
+            ys_out_pad,
+            encoder_out_mask,
+            ignore_id=self.ignore_id,
+            mask_chunk_predictor=mask_chunk_predictor,
+            target_label_length=ys_in_lens,
+        )
+        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
+            pre_alphas, encoder_out_lens
+        )
 
         scama_mask = None
-        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+        if (
+            self.encoder.overlap_chunk_cls is not None
+            and self.decoder_attention_chunk_type == "chunk"
+        ):
             encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
             attention_chunk_center_bias = 0
             attention_chunk_size = encoder_chunk_size
-            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
-            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
-                                                                                                           device=encoder_out.device,
-                                                                                                           batch_size=encoder_out.size(
-                                                                                                               0))
+            decoder_att_look_back_factor = (
+                self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+            )
+            mask_shift_att_chunk_decoder = (
+                self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
+                    None, device=encoder_out.device, batch_size=encoder_out.size(0)
+                )
+            )
             scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                 predictor_alignments=predictor_alignments,
                 encoder_sequence_length=encoder_out_lens,
@@ -728,10 +772,17 @@
                 is_training=self.training,
             )
         elif self.encoder.overlap_chunk_cls is not None:
-            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
-                                                                                        chunk_outs=None)
+            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
+                encoder_out, encoder_out_lens, chunk_outs=None
+            )
 
-        return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
+        return (
+            pre_acoustic_embeds,
+            pre_token_length,
+            predictor_alignments,
+            predictor_alignments_len,
+            scama_mask,
+        )
 
     def calc_predictor_mask2(
         self,
@@ -744,37 +795,49 @@
         # ys_in_lens = ys_pad_lens + 1
         ys_out_pad, ys_in_lens = None, None
 
-        encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
-                                         device=encoder_out.device)[:, None, :]
+        encoder_out_mask = sequence_mask(
+            encoder_out_lens,
+            maxlen=encoder_out.size(1),
+            dtype=encoder_out.dtype,
+            device=encoder_out.device,
+        )[:, None, :]
         mask_chunk_predictor = None
         if self.encoder2.overlap_chunk_cls is not None:
-            mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(None,
-                                                                                            device=encoder_out.device,
-                                                                                            batch_size=encoder_out.size(
-                                                                                                0))
-            mask_shfit_chunk = self.encoder2.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
-                                                                                    batch_size=encoder_out.size(0))
+            mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0)
+            )
+            mask_shfit_chunk = self.encoder2.overlap_chunk_cls.get_mask_shfit_chunk(
+                None, device=encoder_out.device, batch_size=encoder_out.size(0)
+            )
             encoder_out = encoder_out * mask_shfit_chunk
-        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor2(encoder_out,
-                                                                               ys_out_pad,
-                                                                               encoder_out_mask,
-                                                                               ignore_id=self.ignore_id,
-                                                                               mask_chunk_predictor=mask_chunk_predictor,
-                                                                               target_label_length=ys_in_lens,
-                                                                               )
-        predictor_alignments, predictor_alignments_len = self.predictor2.gen_frame_alignments(pre_alphas,
-                                                                                              encoder_out_lens)
+        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor2(
+            encoder_out,
+            ys_out_pad,
+            encoder_out_mask,
+            ignore_id=self.ignore_id,
+            mask_chunk_predictor=mask_chunk_predictor,
+            target_label_length=ys_in_lens,
+        )
+        predictor_alignments, predictor_alignments_len = self.predictor2.gen_frame_alignments(
+            pre_alphas, encoder_out_lens
+        )
 
         scama_mask = None
-        if self.encoder2.overlap_chunk_cls is not None and self.decoder_attention_chunk_type2 == 'chunk':
+        if (
+            self.encoder2.overlap_chunk_cls is not None
+            and self.decoder_attention_chunk_type2 == "chunk"
+        ):
             encoder_chunk_size = self.encoder2.overlap_chunk_cls.chunk_size_pad_shift_cur
             attention_chunk_center_bias = 0
             attention_chunk_size = encoder_chunk_size
-            decoder_att_look_back_factor = self.encoder2.overlap_chunk_cls.decoder_att_look_back_factor_cur
-            mask_shift_att_chunk_decoder = self.encoder2.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
-                                                                                                            device=encoder_out.device,
-                                                                                                            batch_size=encoder_out.size(
-                                                                                                                0))
+            decoder_att_look_back_factor = (
+                self.encoder2.overlap_chunk_cls.decoder_att_look_back_factor_cur
+            )
+            mask_shift_att_chunk_decoder = (
+                self.encoder2.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
+                    None, device=encoder_out.device, batch_size=encoder_out.size(0)
+                )
+            )
             scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn2(
                 predictor_alignments=predictor_alignments,
                 encoder_sequence_length=encoder_out_lens,
@@ -791,14 +854,22 @@
                 is_training=self.training,
             )
         elif self.encoder2.overlap_chunk_cls is not None:
-            encoder_out, encoder_out_lens = self.encoder2.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
-                                                                                         chunk_outs=None)
+            encoder_out, encoder_out_lens = self.encoder2.overlap_chunk_cls.remove_chunk(
+                encoder_out, encoder_out_lens, chunk_outs=None
+            )
 
-        return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
+        return (
+            pre_acoustic_embeds,
+            pre_token_length,
+            predictor_alignments,
+            predictor_alignments_len,
+            scama_mask,
+        )
 
-    def init_beam_search(self,
-                         **kwargs,
-                         ):
+    def init_beam_search(
+        self,
+        **kwargs,
+    ):
         from funasr.models.uniasr.beam_search import BeamSearchScama
         from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
         from funasr.models.transformer.scorers.length_bonus import LengthBonus
@@ -810,23 +881,21 @@
             decoder = self.decoder2
         # 1. Build ASR model
         scorers = {}
-    
+
         if self.ctc != None:
             ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
-            scorers.update(
-                ctc=ctc
-            )
+            scorers.update(ctc=ctc)
         token_list = kwargs.get("token_list")
         scorers.update(
             decoder=decoder,
             length_bonus=LengthBonus(len(token_list)),
         )
-    
+
         # 3. Build ngram model
         # ngram is not supported now
         ngram = None
         scorers["ngram"] = ngram
-    
+
         weights = dict(
             decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
             ctc=kwargs.get("decoding_ctc_weight", 0.0),
@@ -844,17 +913,18 @@
             token_list=token_list,
             pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
         )
-        
+
         self.beam_search = beam_search
 
-    def inference(self,
-                  data_in,
-                  data_lengths=None,
-                  key: list = None,
-                  tokenizer=None,
-                  frontend=None,
-                  **kwargs,
-                  ):
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
 
         decoding_model = kwargs.get("decoding_model", "normal")
         token_num_relax = kwargs.get("token_num_relax", 5)
@@ -868,14 +938,16 @@
             decoding_ind = 0
             decoding_mode = "model2"
         # init beamsearch
-        
+
         if self.beam_search is None:
             logging.info("enable beam_search")
             self.init_beam_search(decoding_mode=decoding_mode, **kwargs)
             self.nbest = kwargs.get("nbest", 1)
-    
+
         meta_data = {}
-        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+        if (
+            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
+        ):  # fbank
             speech, speech_lengths = data_in, data_lengths
             if len(speech.shape) < 3:
                 speech = speech[None, :, :]
@@ -884,17 +956,24 @@
         else:
             # extract fbank feats
             time1 = time.perf_counter()
-            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
-                                                            data_type=kwargs.get("data_type", "sound"),
-                                                            tokenizer=tokenizer)
+            audio_sample_list = load_audio_text_image_video(
+                data_in,
+                fs=frontend.fs,
+                audio_fs=kwargs.get("fs", 16000),
+                data_type=kwargs.get("data_type", "sound"),
+                tokenizer=tokenizer,
+            )
             time2 = time.perf_counter()
             meta_data["load_data"] = f"{time2 - time1:0.3f}"
-            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
-                                                   frontend=frontend)
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+            )
             time3 = time.perf_counter()
             meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-    
+            meta_data["batch_data_time"] = (
+                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+            )
+
         speech = speech.to(device=kwargs["device"])
         speech_lengths = speech_lengths.to(device=kwargs["device"])
         speech_raw = speech.clone().to(device=kwargs["device"])
@@ -903,9 +982,10 @@
         if decoding_mode == "model1":
             predictor_outs = self.calc_predictor_mask(encoder_out, encoder_out_lens)
         else:
-            encoder_out, encoder_out_lens = self.encode2(encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=decoding_ind)
+            encoder_out, encoder_out_lens = self.encode2(
+                encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=decoding_ind
+            )
             predictor_outs = self.calc_predictor_mask2(encoder_out, encoder_out_lens)
-
 
         scama_mask = predictor_outs[4]
         pre_token_length = predictor_outs[1]
@@ -914,8 +994,13 @@
         minlen = max(0, pre_token_length.sum().item() - token_num_relax)
         # c. Passed the encoder result and the beam search
         nbest_hyps = self.beam_search(
-            x=encoder_out[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=0.0,
-            minlenratio=0.0, maxlen=int(maxlen), minlen=int(minlen),
+            x=encoder_out[0],
+            scama_mask=scama_mask,
+            pre_acoustic_embeds=pre_acoustic_embeds,
+            maxlenratio=0.0,
+            minlenratio=0.0,
+            maxlen=int(maxlen),
+            minlen=int(minlen),
         )
 
         nbest_hyps = nbest_hyps[: self.nbest]
@@ -933,15 +1018,13 @@
             # remove blank symbol id, which is assumed to be 0
             token_int = list(filter(lambda x: x != 0, token_int))
 
-
             # Change integer-ids to tokens
             token = tokenizer.ids2tokens(token_int)
             text_postprocessed = tokenizer.tokens2text(token)
             if not hasattr(tokenizer, "bpemodel"):
                 text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-    
 
             result_i = {"key": key[0], "text": text_postprocessed}
             results.append(result_i)
 
-        return results, meta_data
\ No newline at end of file
+        return results, meta_data

--
Gitblit v1.9.1