From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/monotonic_aligner/model.py |  181 ++++++++++++++++++++++++++-------------------
 1 files changed, 105 insertions(+), 76 deletions(-)

diff --git a/funasr/models/monotonic_aligner/model.py b/funasr/models/monotonic_aligner/model.py
index ece319d..4754d1f 100644
--- a/funasr/models/monotonic_aligner/model.py
+++ b/funasr/models/monotonic_aligner/model.py
@@ -1,24 +1,27 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 import time
 import copy
 import torch
 from torch.cuda.amp import autocast
 from typing import Union, Dict, List, Tuple, Optional
 
-from funasr.models.paraformer.cif_predictor import mae_loss
-from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.metrics.compute_acc import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
 from funasr.register import tables
 from funasr.models.ctc.ctc import CTC
-from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
 
-
-@tables.register("model_classes", "monotonicaligner")
+@tables.register("model_classes", "MonotonicAligner")
 class MonotonicAligner(torch.nn.Module):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -41,19 +44,18 @@
         length_normalized_loss: bool = False,
         **kwargs,
     ):
-
         super().__init__()
 
         if specaug is not None:
-            specaug_class = tables.specaug_classes.get(specaug.lower())
+            specaug_class = tables.specaug_classes.get(specaug)
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = tables.normalize_classes.get(normalize.lower())
+            normalize_class = tables.normalize_classes.get(normalize)
             normalize = normalize_class(**normalize_conf)
-        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()
-        predictor_class = tables.predictor_classes.get(predictor.lower())
+        predictor_class = tables.predictor_classes.get(predictor)
         predictor = predictor_class(**predictor_conf)
         self.specaug = specaug
         self.normalize = normalize
@@ -63,11 +65,11 @@
         self.predictor_bias = predictor_bias
 
     def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
         Args:
@@ -79,25 +81,25 @@
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
         assert (
-                speech.shape[0]
-                == speech_lengths.shape[0]
-                == text.shape[0]
-                == text_lengths.shape[0]
+            speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
         ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
         batch_size = speech.shape[0]
         # for data-parallel
         text = text[:, : text_lengths.max()]
-        speech = speech[:, :speech_lengths.max()]
+        speech = speech[:, : speech_lengths.max()]
 
         # 1. Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
-        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-            encoder_out.device)
+        encoder_out_mask = (
+            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
+        ).to(encoder_out.device)
         if self.predictor_bias == 1:
             _, text = add_sos_eos(text, 1, 2, -1)
             text_lengths = text_lengths + self.predictor_bias
-        _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
+        _, _, _, _, pre_token_length2 = self.predictor(
+            encoder_out, text, encoder_out_mask, ignore_id=-1
+        )
 
         # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
         loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
@@ -114,15 +116,19 @@
         return loss, stats, weight
 
     def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
-        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-            encoder_out.device)
-        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
-                                                                                            encoder_out_mask,
-                                                                                            token_num)
+        encoder_out_mask = (
+            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
+        ).to(encoder_out.device)
+        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(
+            encoder_out, encoder_out_mask, token_num
+        )
         return ds_alphas, ds_cif_peak, us_alphas, us_peaks
 
     def encode(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Encoder. Note that this method is used by asr_inference.py
         Args:
@@ -135,68 +141,91 @@
             # Data augmentation
             if self.specaug is not None and self.training:
                 speech, speech_lengths = self.specaug(speech, speech_lengths)
-            
+
             # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
             if self.normalize is not None:
                 speech, speech_lengths = self.normalize(speech, speech_lengths)
-        
+
         # Forward encoder
         encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
 
         return encoder_out, encoder_out_lens
-    
-    def generate(self,
-             data_in,
-             data_lengths=None,
-             key: list=None,
-             tokenizer=None,
-             frontend=None,
-             **kwargs,
-             ):
-        
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
         meta_data = {}
         # extract fbank feats
         time1 = time.perf_counter()
-        audio_list, text_token_int_list = load_audio_and_text_image_video(data_in, 
-                                                                            fs=frontend.fs, 
-                                                                            audio_fs=kwargs.get("fs", 16000), 
-                                                                            data_type=kwargs.get("data_type", "sound"), 
-                                                                            tokenizer=tokenizer)
+        audio_list, text_token_int_list = load_audio_text_image_video(
+            data_in,
+            fs=frontend.fs,
+            audio_fs=kwargs.get("fs", 16000),
+            data_type=kwargs.get("data_type", "sound"),
+            tokenizer=tokenizer,
+        )
         time2 = time.perf_counter()
         meta_data["load_data"] = f"{time2 - time1:0.3f}"
-        speech, speech_lengths = extract_fbank(audio_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend)
+        speech, speech_lengths = extract_fbank(
+            audio_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+        )
         time3 = time.perf_counter()
         meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-        meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-            
-        speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+        meta_data["batch_data_time"] = (
+            speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+        )
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
 
         # Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-        
+
         # predictor
-        text_lengths = torch.tensor([len(i)+1 for i in text_token_int_list]).to(encoder_out.device)
-        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num=text_lengths)
-        
+        text_lengths = torch.tensor([len(i) + 1 for i in text_token_int_list]).to(
+            encoder_out.device
+        )
+        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(
+            encoder_out, encoder_out_lens, token_num=text_lengths
+        )
+
         results = []
         ibest_writer = None
-        if ibest_writer is None and kwargs.get("output_dir") is not None:
-            writer = DatadirWriter(kwargs.get("output_dir"))
-            ibest_writer = writer["tp_res"]
-        for i, (us_alpha, us_peak, token_int) in enumerate(zip(us_alphas, us_peaks, text_token_int_list)):
+        if kwargs.get("output_dir") is not None:
+            if not hasattr(self, "writer"):
+                self.writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = self.writer["tp_res"]
+
+        for i, (us_alpha, us_peak, token_int) in enumerate(
+            zip(us_alphas, us_peaks, text_token_int_list)
+        ):
             token = tokenizer.ids2tokens(token_int)
-            timestamp_str, timestamp = ts_prediction_lfr6_standard(us_alpha[:encoder_out_lens[i] * 3],
-                                                                   us_peak[:encoder_out_lens[i] * 3],
-                                                                   copy.copy(token))
-            text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
-                token, timestamp)
-            result_i = {"key": key[i], "text": text_postprocessed,
-                                "timestamp": time_stamp_postprocessed,
-                                }    
-            # ibest_writer["token"][key[i]] = " ".join(token)
-            ibest_writer["timestamp_list"][key[i]] = time_stamp_postprocessed
-            ibest_writer["timestamp_str"][key[i]] = timestamp_str
+            timestamp_str, timestamp = ts_prediction_lfr6_standard(
+                us_alpha[: encoder_out_lens[i] * 3],
+                us_peak[: encoder_out_lens[i] * 3],
+                copy.copy(token),
+            )
+            text_postprocessed, time_stamp_postprocessed, _ = (
+                postprocess_utils.sentence_postprocess(token, timestamp)
+            )
+            result_i = {
+                "key": key[i],
+                "text": text_postprocessed,
+                "timestamp": time_stamp_postprocessed,
+            }
             results.append(result_i)
-        return results, meta_data
\ No newline at end of file
+
+            if ibest_writer:
+                # ibest_writer["token"][key[i]] = " ".join(token)
+                ibest_writer["timestamp_list"][key[i]] = time_stamp_postprocessed
+                ibest_writer["timestamp_str"][key[i]] = timestamp_str
+
+        return results, meta_data

--
Gitblit v1.9.1