From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords
---
funasr/utils/timestamp_tools.py | 91 ++++++++++++++++++++++++++++++++++++++++++++-
1 files changed, 88 insertions(+), 3 deletions(-)
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 831d773..6abebe1 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -29,13 +29,13 @@
def ts_prediction_lfr6_standard(
- us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True
+ us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
):
if not len(char_list):
return "", []
START_END_THRESHOLD = 5
- MAX_TOKEN_DURATION = 12
- TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
+ MAX_TOKEN_DURATION = 12 # 3 times upsampled
+ TIME_RATE=10.0 * 6 / 1000 / upsample_rate
if len(us_alphas.shape) == 2:
alphas, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
else:
@@ -185,3 +185,88 @@
ts_list = []
sentence_start = sentence_end
return res
+
+
+def timestamp_sentence_en(
+ punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False
+):
+ punc_list = [",", ".", "?", ","]
+ res = []
+ if text_postprocessed is None:
+ return res
+ if timestamp_postprocessed is None:
+ return res
+ if len(timestamp_postprocessed) == 0:
+ return res
+ if len(text_postprocessed) == 0:
+ return res
+
+ if punc_id_list is None or len(punc_id_list) == 0:
+ res.append(
+ {
+ "text": text_postprocessed.split(),
+ "start": timestamp_postprocessed[0][0],
+ "end": timestamp_postprocessed[-1][1],
+ "timestamp": timestamp_postprocessed,
+ }
+ )
+ return res
+ if len(punc_id_list) != len(timestamp_postprocessed):
+ logging.warning("length mismatch between punc and timestamp")
+ sentence_text = ""
+ sentence_text_seg = ""
+ ts_list = []
+ sentence_start = timestamp_postprocessed[0][0]
+ sentence_end = timestamp_postprocessed[0][1]
+ texts = text_postprocessed.split()
+ punc_stamp_text_list = list(
+ zip_longest(punc_id_list, timestamp_postprocessed, texts, fillvalue=None)
+ )
+ for punc_stamp_text in punc_stamp_text_list:
+ punc_id, timestamp, text = punc_stamp_text
+ # sentence_text += text if text is not None else ''
+ if text is not None:
+ if "a" <= text[0] <= "z" or "A" <= text[0] <= "Z":
+ sentence_text += " " + text
+ elif len(sentence_text) and (
+ "a" <= sentence_text[-1] <= "z" or "A" <= sentence_text[-1] <= "Z"
+ ):
+ sentence_text += " " + text
+ else:
+ sentence_text += text
+ sentence_text_seg += text + " "
+ ts_list.append(timestamp)
+
+ punc_id = int(punc_id) if punc_id is not None else 1
+ sentence_end = timestamp[1] if timestamp is not None else sentence_end
+ sentence_text = sentence_text[1:] if sentence_text[0] == ' ' else sentence_text
+
+ if punc_id > 1:
+ sentence_text += punc_list[punc_id - 2]
+ sentence_text_seg = (
+ sentence_text_seg[:-1] if sentence_text_seg[-1] == " " else sentence_text_seg
+ )
+ if return_raw_text:
+ res.append(
+ {
+ "text": sentence_text,
+ "start": sentence_start,
+ "end": sentence_end,
+ "timestamp": ts_list,
+ "raw_text": sentence_text_seg,
+ }
+ )
+ else:
+ res.append(
+ {
+ "text": sentence_text,
+ "start": sentence_start,
+ "end": sentence_end,
+ "timestamp": ts_list,
+ }
+ )
+ sentence_text = ""
+ sentence_text_seg = ""
+ ts_list = []
+ sentence_start = sentence_end
+ return res
\ No newline at end of file
--
Gitblit v1.9.1