From 23e7ddebccd3b05cf7ef89809bcfe565ad6dfa1f Mon Sep 17 00:00:00 2001
From: majic31 <majic31@163.com>
Date: 星期二, 24 十二月 2024 10:00:14 +0800
Subject: [PATCH] Fix the variable name (#2328)

---
 funasr/utils/timestamp_tools.py |  101 ++++++++++++++++++++++++++++++++++++++++++++++++--
 1 files changed, 96 insertions(+), 5 deletions(-)

diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 831d773..995e179 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -29,13 +29,13 @@
 
 
 def ts_prediction_lfr6_standard(
-    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True
+    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
 ):
     if not len(char_list):
         return "", []
     START_END_THRESHOLD = 5
-    MAX_TOKEN_DURATION = 12
-    TIME_RATE = 10.0 * 6 / 1000 / 3  #  3 times upsampled
+    MAX_TOKEN_DURATION = 12  #  3 times upsampled
+    TIME_RATE=10.0 * 6 / 1000 / upsample_rate
     if len(us_alphas.shape) == 2:
         alphas, peaks = us_alphas[0], us_peaks[0]  # support inference batch_size=1 only
     else:
@@ -84,7 +84,8 @@
         timestamp_list.append([_end * TIME_RATE, num_frames * TIME_RATE])
         new_char_list.append("<sil>")
     else:
-        timestamp_list[-1][1] = num_frames * TIME_RATE
+        if len(timestamp_list)>0:
+            timestamp_list[-1][1] = num_frames * TIME_RATE
     if vad_offset:  # add offset time in model with vad
         for i in range(len(timestamp_list)):
             timestamp_list[i][0] = timestamp_list[i][0] + vad_offset / 1000.0
@@ -141,6 +142,8 @@
     )
     for punc_stamp_text in punc_stamp_text_list:
         punc_id, timestamp, text = punc_stamp_text
+        if sentence_start is None and timestamp is not None:
+            sentence_start = timestamp[0]
         # sentence_text += text if text is not None else ''
         if text is not None:
             if "a" <= text[0] <= "z" or "A" <= text[0] <= "Z":
@@ -183,5 +186,93 @@
             sentence_text = ""
             sentence_text_seg = ""
             ts_list = []
-            sentence_start = sentence_end
+            sentence_start = None
+    return res
+
+
+def timestamp_sentence_en(
+    punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False
+):
+    punc_list = [",", ".", "?", ","]
+    res = []
+    if text_postprocessed is None:
+        return res
+    if timestamp_postprocessed is None:
+        return res
+    if len(timestamp_postprocessed) == 0:
+        return res
+    if len(text_postprocessed) == 0:
+        return res
+
+    if punc_id_list is None or len(punc_id_list) == 0:
+        res.append(
+            {
+                "text": text_postprocessed.split(),
+                "start": timestamp_postprocessed[0][0],
+                "end": timestamp_postprocessed[-1][1],
+                "timestamp": timestamp_postprocessed,
+            }
+        )
+        return res
+    if len(punc_id_list) != len(timestamp_postprocessed):
+        logging.warning("length mismatch between punc and timestamp")
+    sentence_text = ""
+    sentence_text_seg = ""
+    ts_list = []
+    sentence_start = timestamp_postprocessed[0][0]
+    sentence_end = timestamp_postprocessed[0][1]
+    texts = text_postprocessed.split()
+    punc_stamp_text_list = list(
+        zip_longest(punc_id_list, timestamp_postprocessed, texts, fillvalue=None)
+    )
+    is_sentence_start = True
+    for punc_stamp_text in punc_stamp_text_list:
+        punc_id, timestamp, text = punc_stamp_text
+        # sentence_text += text if text is not None else ''
+        if text is not None:
+            if "a" <= text[0] <= "z" or "A" <= text[0] <= "Z":
+                sentence_text += " " + text
+            elif len(sentence_text) and (
+                "a" <= sentence_text[-1] <= "z" or "A" <= sentence_text[-1] <= "Z"
+            ):
+                sentence_text += " " + text
+            else:
+                sentence_text += text
+            sentence_text_seg += text + " "
+        ts_list.append(timestamp)
+
+        punc_id = int(punc_id) if punc_id is not None else 1
+        sentence_end = timestamp[1] if timestamp is not None else sentence_end
+        sentence_text = sentence_text[1:] if sentence_text[0] == ' ' else sentence_text
+        if is_sentence_start:
+            sentence_start = timestamp[0] if timestamp is not None else sentence_start
+            is_sentence_start = False
+        if punc_id > 1:
+            is_sentence_start = True
+            sentence_text += punc_list[punc_id - 2]
+            sentence_text_seg = (
+                sentence_text_seg[:-1] if sentence_text_seg[-1] == " " else sentence_text_seg
+            )
+            if return_raw_text:
+                res.append(
+                    {
+                        "text": sentence_text,
+                        "start": sentence_start,
+                        "end": sentence_end,
+                        "timestamp": ts_list,
+                        "raw_text": sentence_text_seg,
+                    }
+                )
+            else:
+                res.append(
+                    {
+                        "text": sentence_text,
+                        "start": sentence_start,
+                        "end": sentence_end,
+                        "timestamp": ts_list,
+                    }
+                )
+            sentence_text = ""
+            sentence_text_seg = ""
+            ts_list = []
     return res

--
Gitblit v1.9.1