From 3e9319263835bd018abe2dcd59e029603b714022 Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: Tue, 11 Jun 2024 11:52:26 +0800
Subject: [PATCH] English timestamp for vanilla Paraformer

---
 examples/industrial_data_pretraining/paraformer/demo.py |   13 ++++
 funasr/utils/timestamp_tools.py                         |   85 ++++++++++++++++++++++++++++
 funasr/auto/auto_model.py                               |   43 ++++++++++----
 3 files changed, 128 insertions(+), 13 deletions(-)

diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py
index eb7e72f..647669f 100644
--- a/examples/industrial_data_pretraining/paraformer/demo.py
+++ b/examples/industrial_data_pretraining/paraformer/demo.py
@@ -21,6 +21,19 @@
 print(res)
 
 
+""" call english model like below for detailed timestamps
+# choose english paraformer model first
+# iic/speech_paraformer_asr-en-16k-vocab4199-pytorch
+res = model.generate(
+    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav",
+    cache={},
+    pred_timestamp=True,
+    return_raw_text=True,
+    sentence_timestamp=True,
+    en_post_proc=True,
+)
+"""
+
 """ can not use currently
 from funasr import AutoFrontend
 
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 7b5a02f..fb81608 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -19,6 +19,7 @@
 from funasr.utils.load_utils import load_bytes
 from funasr.download.file import download_from_url
 from funasr.utils.timestamp_tools import timestamp_sentence
+from funasr.utils.timestamp_tools import timestamp_sentence_en
 from funasr.download.download_from_hub import download_model
 from funasr.utils.vad_utils import slice_padding_audio_samples
 from funasr.utils.vad_utils import merge_vad
@@ -321,7 +322,7 @@
             input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg
         )
         end_vad = time.time()
-
+
         #  FIX(gcf): concat the vad clips for sense vocie model for better aed
         if kwargs.get("merge_vad", False):
             for i in range(len(res)):
@@ -513,24 +514,40 @@
                                        and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
                                        can predict timestamp, and speaker diarization relies on timestamps."
                         )
-                    sentence_list = timestamp_sentence(
-                        punc_res[0]["punc_array"],
-                        result["timestamp"],
-                        raw_text,
-                        return_raw_text=return_raw_text,
-                    )
+                    if kwargs.get("en_post_proc", False):
+                        sentence_list = timestamp_sentence_en(
+                            punc_res[0]["punc_array"],
+                            result["timestamp"],
+                            raw_text,
+                            return_raw_text=return_raw_text,
+                        )
+                    else:
+                        sentence_list = timestamp_sentence(
+                            punc_res[0]["punc_array"],
+                            result["timestamp"],
+                            raw_text,
+                            return_raw_text=return_raw_text,
+                        )
                 distribute_spk(sentence_list, sv_output)
                 result["sentence_info"] = sentence_list
             elif kwargs.get("sentence_timestamp", False):
                 if not len(result["text"].strip()):
                     sentence_list = []
                 else:
-                    sentence_list = timestamp_sentence(
-                        punc_res[0]["punc_array"],
-                        result["timestamp"],
-                        raw_text,
-                        return_raw_text=return_raw_text,
-                    )
+                    if kwargs.get("en_post_proc", False):
+                        sentence_list = timestamp_sentence_en(
+                            punc_res[0]["punc_array"],
+                            result["timestamp"],
+                            raw_text,
+                            return_raw_text=return_raw_text,
+                        )
+                    else:
+                        sentence_list = timestamp_sentence(
+                            punc_res[0]["punc_array"],
+                            result["timestamp"],
+                            raw_text,
+                            return_raw_text=return_raw_text,
+                        )
                 result["sentence_info"] = sentence_list
             if "spk_embedding" in result:
                 del result["spk_embedding"]
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index af61e5a..6abebe1 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -185,3 +185,88 @@
             ts_list = []
             sentence_start = sentence_end
     return res
+
+
+def timestamp_sentence_en(
+    punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False
+):
+    punc_list = [",", ".", "?", ","]
+    res = []
+    if text_postprocessed is None:
+        return res
+    if timestamp_postprocessed is None:
+        return res
+    if len(timestamp_postprocessed) == 0:
+        return res
+    if len(text_postprocessed) == 0:
+        return res
+
+    if punc_id_list is None or len(punc_id_list) == 0:
+        res.append(
+            {
+                "text": text_postprocessed.split(),
+                "start": timestamp_postprocessed[0][0],
+                "end": timestamp_postprocessed[-1][1],
+                "timestamp": timestamp_postprocessed,
+            }
+        )
+        return res
+    if len(punc_id_list) != len(timestamp_postprocessed):
+        logging.warning("length mismatch between punc and timestamp")
+    sentence_text = ""
+    sentence_text_seg = ""
+    ts_list = []
+    sentence_start = timestamp_postprocessed[0][0]
+    sentence_end = timestamp_postprocessed[0][1]
+    texts = text_postprocessed.split()
+    punc_stamp_text_list = list(
+        zip_longest(punc_id_list, timestamp_postprocessed, texts, fillvalue=None)
+    )
+    for punc_stamp_text in punc_stamp_text_list:
+        punc_id, timestamp, text = punc_stamp_text
+        # sentence_text += text if text is not None else ''
+        if text is not None:
+            if "a" <= text[0] <= "z" or "A" <= text[0] <= "Z":
+                sentence_text += " " + text
+            elif len(sentence_text) and (
+                "a" <= sentence_text[-1] <= "z" or "A" <= sentence_text[-1] <= "Z"
+            ):
+                sentence_text += " " + text
+            else:
+                sentence_text += text
+            sentence_text_seg += text + " "
+        ts_list.append(timestamp)
+
+        punc_id = int(punc_id) if punc_id is not None else 1
+        sentence_end = timestamp[1] if timestamp is not None else sentence_end
+        sentence_text = sentence_text[1:] if sentence_text[:1] == ' ' else sentence_text
+
+        if punc_id > 1:
+            sentence_text += punc_list[punc_id - 2]
+            sentence_text_seg = (
+                sentence_text_seg[:-1] if sentence_text_seg[-1:] == " " else sentence_text_seg
+            )
+            if return_raw_text:
+                res.append(
+                    {
+                        "text": sentence_text,
+                        "start": sentence_start,
+                        "end": sentence_end,
+                        "timestamp": ts_list,
+                        "raw_text": sentence_text_seg,
+                    }
+                )
+            else:
+                res.append(
+                    {
+                        "text": sentence_text,
+                        "start": sentence_start,
+                        "end": sentence_end,
+                        "timestamp": ts_list,
+                    }
+                )
+            sentence_text = ""
+            sentence_text_seg = ""
+            ts_list = []
+            sentence_start = sentence_end
+    return res
\ No newline at end of file

--
Gitblit v1.9.1