From cf8e000a84e888495dcf30c4dbfecea1ee7ab4e2 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 07 八月 2023 16:13:37 +0800
Subject: [PATCH] Merge pull request #807 from alibaba-damo-academy/dev_wjm

---
 funasr/utils/timestamp_tools.py |   48 +++++++++++++-----------------------------------
 1 files changed, 13 insertions(+), 35 deletions(-)

diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 489d317..5787f1d 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -1,14 +1,10 @@
-from itertools import zip_longest
-
 import torch
-import copy
 import codecs
 import logging
-import edit_distance
 import argparse
-import pdb
 import numpy as np
-from typing import Any, List, Tuple, Union
+import edit_distance
+from itertools import zip_longest
 
 
 def ts_prediction_lfr6_standard(us_alphas, 
@@ -36,7 +32,14 @@
     # so treat the frames between two peaks as the duration of the former token
     fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
     num_peak = len(fire_place)
-    assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
+    if num_peak != len(char_list) + 1:
+        logging.warning("length mismatch, result might be incorrect.")
+        logging.warning("num_peaks: {}, num_chars+1: {}, which is supposed to be same.".format(num_peak, len(char_list)+1))
+    if num_peak > len(char_list) + 1:
+        fire_place = fire_place[:len(char_list) - 1]
+    elif num_peak < len(char_list) + 1:
+        char_list = char_list[:num_peak + 1]
+    # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
     # begin silence
     if fire_place[0] > START_END_THRESHOLD:
         # char_list.insert(0, '<sil>')
@@ -80,6 +83,7 @@
 
 
 def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
+    punc_list = ['锛�', '銆�', '锛�', '銆�']
     res = []
     if text_postprocessed is None:
         return res
@@ -124,34 +128,8 @@
         punc_id = int(punc_id) if punc_id is not None else 1
         sentence_end = time_stamp[1] if time_stamp is not None else sentence_end
 
-        if punc_id == 2:
-            sentence_text += ','
-            res.append({
-                'text': sentence_text,
-                "start": sentence_start,
-                "end": sentence_end,
-                "text_seg": sentence_text_seg,
-                "ts_list": ts_list
-            })
-            sentence_text = ''
-            sentence_text_seg = ''
-            ts_list = []
-            sentence_start = sentence_end
-        elif punc_id == 3:
-            sentence_text += '.'
-            res.append({
-                'text': sentence_text,
-                "start": sentence_start,
-                "end": sentence_end,
-                "text_seg": sentence_text_seg,
-                "ts_list": ts_list
-            })
-            sentence_text = ''
-            sentence_text_seg = ''
-            ts_list = []
-            sentence_start = sentence_end
-        elif punc_id == 4:
-            sentence_text += '?'
+        if punc_id > 1:
+            sentence_text += punc_list[punc_id - 2]
             res.append({
                 'text': sentence_text,
                 "start": sentence_start,

--
Gitblit v1.9.1