From a1fe3c635f47e941c2bb2a545ce0aface87fe041 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期一, 13 三月 2023 17:26:20 +0800
Subject: [PATCH] update tp inference

---
 funasr/models/e2e_tp.py    |   21 +++++++++++++++++++++
 funasr/tasks/asr.py        |   14 ++------------
 funasr/bin/tp_inference.py |    2 +-
 3 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/funasr/bin/tp_inference.py b/funasr/bin/tp_inference.py
index 8c9c0f3..e374a22 100644
--- a/funasr/bin/tp_inference.py
+++ b/funasr/bin/tp_inference.py
@@ -18,7 +18,7 @@
 
 from funasr.fileio.datadir_writer import DatadirWriter
 from funasr.datasets.preprocessor import LMPreprocessor
-from funasr.tasks.asr import ASRTaskAligner_temp as ASRTask
+from funasr.tasks.asr import ASRTaskAligner as ASRTask
 from funasr.torch_utils.device_funcs import to_device
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py
index 8808008..9850051 100644
--- a/funasr/models/e2e_tp.py
+++ b/funasr/models/e2e_tp.py
@@ -41,6 +41,7 @@
             encoder: AbsEncoder,
             predictor: CifPredictorV3,
             predictor_bias: int = 0,
+            token_list=None,
     ):
         assert check_argument_types()
 
@@ -54,6 +55,7 @@
         self.predictor = predictor
         self.predictor_bias = predictor_bias
         self.criterion_pre = mae_loss()
+        self.token_list = token_list
     
     def forward(
             self,
@@ -152,3 +154,22 @@
                                                                                                encoder_out_mask,
                                                                                                token_num)
         return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
+
+    def collect_feats(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
+    ) -> Dict[str, torch.Tensor]:
+        if self.extract_feats_in_collect_stats:
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        else:
+            # Generate dummy stats if extract_feats_in_collect_stats is False
+            logging.warning(
+                "Generating dummy stats for feats and feats_lengths, "
+                "because encoder_conf.extract_feats_in_collect_stats is "
+                f"{self.extract_feats_in_collect_stats}"
+            )
+            feats, feats_lengths = speech, speech_lengths
+        return {"feats": feats, "feats_lengths": feats_lengths}
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index fc2dbbc..36499a2 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -125,7 +125,7 @@
         bicif_paraformer=BiCifParaformer,
         contextual_paraformer=ContextualParaformer,
         mfcca=MFCCA,
-        timestamp_predictor=TimestampPredictor,
+        timestamp_prediction=TimestampPredictor,
     ),
     type_check=AbsESPnetModel,
     default="asr",
@@ -1278,8 +1278,6 @@
             token_list = list(args.token_list)
         else:
             raise RuntimeError("token_list must be str or list")
-        vocab_size = len(token_list)
-        logging.info(f"Vocabulary size: {vocab_size}")
 
         # 1. frontend
         if args.input_size is None:
@@ -1316,6 +1314,7 @@
             frontend=frontend,
             encoder=encoder,
             predictor=predictor,
+            token_list=token_list,
             **args.model_conf,
         )
 
@@ -1332,12 +1331,3 @@
     ) -> Tuple[str, ...]:
         retval = ("speech", "text")
         return retval
-
-
-class ASRTaskAligner_temp(ASRTaskParaformer):
-    @classmethod
-    def required_data_names(
-            cls, train: bool = True, inference: bool = False
-    ) -> Tuple[str, ...]:
-        retval = ("speech", "text")
-        return retval
\ No newline at end of file

--
Gitblit v1.9.1