From a1fe3c635f47e941c2bb2a545ce0aface87fe041 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期一, 13 三月 2023 17:26:20 +0800
Subject: [PATCH] update tp inference

---
 funasr/models/e2e_tp.py |   21 +++++++++++++++++++++
 1 files changed, 21 insertions(+), 0 deletions(-)

diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py
index 8808008..9850051 100644
--- a/funasr/models/e2e_tp.py
+++ b/funasr/models/e2e_tp.py
@@ -41,6 +41,7 @@
             encoder: AbsEncoder,
             predictor: CifPredictorV3,
             predictor_bias: int = 0,
+            token_list=None,
     ):
         assert check_argument_types()
 
@@ -54,6 +55,7 @@
         self.predictor = predictor
         self.predictor_bias = predictor_bias
         self.criterion_pre = mae_loss()
+        self.token_list = token_list
     
     def forward(
             self,
@@ -152,3 +154,22 @@
                                                                                                encoder_out_mask,
                                                                                                token_num)
         return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
+
+    def collect_feats(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
+    ) -> Dict[str, torch.Tensor]:
+        if self.extract_feats_in_collect_stats:
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        else:
+            # Generate dummy stats if extract_feats_in_collect_stats is False
+            logging.warning(
+                "Generating dummy stats for feats and feats_lengths, "
+                "because encoder_conf.extract_feats_in_collect_stats is "
+                f"{self.extract_feats_in_collect_stats}"
+            )
+            feats, feats_lengths = speech, speech_lengths
+        return {"feats": feats, "feats_lengths": feats_lengths}

--
Gitblit v1.9.1