From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/monotonic_aligner/model.py |  136 +++++++++++++++++++++++++++------------------
 1 files changed, 81 insertions(+), 55 deletions(-)

diff --git a/funasr/models/monotonic_aligner/model.py b/funasr/models/monotonic_aligner/model.py
index 77d95a0..4754d1f 100644
--- a/funasr/models/monotonic_aligner/model.py
+++ b/funasr/models/monotonic_aligner/model.py
@@ -28,6 +28,7 @@
     Achieving timestamp prediction while recognizing with non-autoregressive end-to-end ASR model
     https://arxiv.org/abs/2301.12343
     """
+
     def __init__(
         self,
         input_size: int = 80,
@@ -64,11 +65,11 @@
         self.predictor_bias = predictor_bias
 
     def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
         Args:
@@ -80,25 +81,25 @@
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
         assert (
-                speech.shape[0]
-                == speech_lengths.shape[0]
-                == text.shape[0]
-                == text_lengths.shape[0]
+            speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
         ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
         batch_size = speech.shape[0]
         # for data-parallel
         text = text[:, : text_lengths.max()]
-        speech = speech[:, :speech_lengths.max()]
+        speech = speech[:, : speech_lengths.max()]
 
         # 1. Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
-        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-            encoder_out.device)
+        encoder_out_mask = (
+            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
+        ).to(encoder_out.device)
         if self.predictor_bias == 1:
             _, text = add_sos_eos(text, 1, 2, -1)
             text_lengths = text_lengths + self.predictor_bias
-        _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
+        _, _, _, _, pre_token_length2 = self.predictor(
+            encoder_out, text, encoder_out_mask, ignore_id=-1
+        )
 
         # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
         loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
@@ -115,15 +116,19 @@
         return loss, stats, weight
 
     def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
-        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-            encoder_out.device)
-        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
-                                                                                            encoder_out_mask,
-                                                                                            token_num)
+        encoder_out_mask = (
+            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
+        ).to(encoder_out.device)
+        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(
+            encoder_out, encoder_out_mask, token_num
+        )
         return ds_alphas, ds_cif_peak, us_alphas, us_peaks
 
     def encode(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Encoder. Note that this method is used by asr_inference.py
         Args:
@@ -136,70 +141,91 @@
             # Data augmentation
             if self.specaug is not None and self.training:
                 speech, speech_lengths = self.specaug(speech, speech_lengths)
-            
+
             # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
             if self.normalize is not None:
                 speech, speech_lengths = self.normalize(speech, speech_lengths)
-        
+
         # Forward encoder
         encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
 
         return encoder_out, encoder_out_lens
-    
-    def inference(self,
-             data_in,
-             data_lengths=None,
-             key: list=None,
-             tokenizer=None,
-             frontend=None,
-             **kwargs,
-             ):
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
         meta_data = {}
         # extract fbank feats
         time1 = time.perf_counter()
-        audio_list, text_token_int_list = load_audio_text_image_video(data_in, 
-                                                                            fs=frontend.fs, 
-                                                                            audio_fs=kwargs.get("fs", 16000), 
-                                                                            data_type=kwargs.get("data_type", "sound"), 
-                                                                            tokenizer=tokenizer)
+        audio_list, text_token_int_list = load_audio_text_image_video(
+            data_in,
+            fs=frontend.fs,
+            audio_fs=kwargs.get("fs", 16000),
+            data_type=kwargs.get("data_type", "sound"),
+            tokenizer=tokenizer,
+        )
         time2 = time.perf_counter()
         meta_data["load_data"] = f"{time2 - time1:0.3f}"
-        speech, speech_lengths = extract_fbank(audio_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend)
+        speech, speech_lengths = extract_fbank(
+            audio_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+        )
         time3 = time.perf_counter()
         meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-        meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-            
+        meta_data["batch_data_time"] = (
+            speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+        )
+
         speech = speech.to(device=kwargs["device"])
         speech_lengths = speech_lengths.to(device=kwargs["device"])
 
         # Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-        
+
         # predictor
-        text_lengths = torch.tensor([len(i)+1 for i in text_token_int_list]).to(encoder_out.device)
-        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num=text_lengths)
-        
+        text_lengths = torch.tensor([len(i) + 1 for i in text_token_int_list]).to(
+            encoder_out.device
+        )
+        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(
+            encoder_out, encoder_out_lens, token_num=text_lengths
+        )
+
         results = []
         ibest_writer = None
-        if ibest_writer is None and kwargs.get("output_dir") is not None:
-            writer = DatadirWriter(kwargs.get("output_dir"))
-            ibest_writer = writer["tp_res"]
-        for i, (us_alpha, us_peak, token_int) in enumerate(zip(us_alphas, us_peaks, text_token_int_list)):
+        if kwargs.get("output_dir") is not None:
+            if not hasattr(self, "writer"):
+                self.writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = self.writer["tp_res"]
+
+        for i, (us_alpha, us_peak, token_int) in enumerate(
+            zip(us_alphas, us_peaks, text_token_int_list)
+        ):
             token = tokenizer.ids2tokens(token_int)
-            timestamp_str, timestamp = ts_prediction_lfr6_standard(us_alpha[:encoder_out_lens[i] * 3],
-                                                                   us_peak[:encoder_out_lens[i] * 3],
-                                                                   copy.copy(token))
-            text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
-            result_i = {"key": key[i], "text": text_postprocessed,
-                                "timestamp": time_stamp_postprocessed,
-                                }
+            timestamp_str, timestamp = ts_prediction_lfr6_standard(
+                us_alpha[: encoder_out_lens[i] * 3],
+                us_peak[: encoder_out_lens[i] * 3],
+                copy.copy(token),
+            )
+            text_postprocessed, time_stamp_postprocessed, _ = (
+                postprocess_utils.sentence_postprocess(token, timestamp)
+            )
+            result_i = {
+                "key": key[i],
+                "text": text_postprocessed,
+                "timestamp": time_stamp_postprocessed,
+            }
             results.append(result_i)
 
             if ibest_writer:
                 # ibest_writer["token"][key[i]] = " ".join(token)
                 ibest_writer["timestamp_list"][key[i]] = time_stamp_postprocessed
                 ibest_writer["timestamp_str"][key[i]] = timestamp_str
-            
-        return results, meta_data
\ No newline at end of file
+
+        return results, meta_data

--
Gitblit v1.9.1