From 20aa07268a7fafaaab7762b488615af32a0e82b4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Tue, 11 Jun 2024 14:02:27 +0800
Subject: [PATCH] update with main (#1800)

---
 funasr/models/paraformer/model.py |   24 ++++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 0d9bb2b..85967af 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -4,6 +4,7 @@
 #  MIT License  (https://opensource.org/licenses/MIT)
 
 import time
+import copy
 import torch
 import logging
 from torch.cuda.amp import autocast
@@ -21,6 +22,7 @@
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
 
@@ -452,6 +454,7 @@
         is_use_lm = (
             kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
         )
+        pred_timestamp = kwargs.get("pred_timestamp", False)
         if self.beam_search is None and (is_use_lm or is_use_ctc):
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
@@ -506,6 +509,7 @@
             predictor_outs[2],
             predictor_outs[3],
         )
+        
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
@@ -564,10 +568,22 @@
                     # Change integer-ids to tokens
                     token = tokenizer.ids2tokens(token_int)
                     text_postprocessed = tokenizer.tokens2text(token)
-                    if not hasattr(tokenizer, "bpemodel"):
-                        text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-
-                    result_i = {"key": key[i], "text": text_postprocessed}
+                    
+                    if pred_timestamp:
+                        timestamp_str, timestamp = ts_prediction_lfr6_standard(
+                            pre_peak_index[i],
+                            alphas[i],
+                            copy.copy(token),
+                            vad_offset=kwargs.get("begin_time", 0),
+                            upsample_rate=1,
+                        )
+                        if not hasattr(tokenizer, "bpemodel"):
+                            text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
+                        result_i = {"key": key[i], "text": text_postprocessed, "timestamp": time_stamp_postprocessed,}
+                    else:
+                        if not hasattr(tokenizer, "bpemodel"):
+                            text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                        result_i = {"key": key[i], "text": text_postprocessed}
 
                     if ibest_writer is not None:
                         ibest_writer["token"][key[i]] = " ".join(token)

--
Gitblit v1.9.1