From 741d089eb9dd9be7b6e2cabbd40fc0a784eb38f3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 21 二月 2024 16:28:58 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR merge

---
 funasr/models/paraformer/model.py                             |    1 
 funasr/utils/timestamp_tools.py                               |   25 ++++++++----
 funasr/auto/auto_model.py                                     |   34 +++++++++-------
 examples/industrial_data_pretraining/seaco_paraformer/demo.py |    7 ++-
 4 files changed, 40 insertions(+), 27 deletions(-)

diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index 804acdd..a44c649 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -11,15 +11,16 @@
                   vad_model_revision="v2.0.4",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                   punc_model_revision="v2.0.4",
-                  spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
-                  spk_model_revision="v2.0.2",
+                  # spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
+                  # spk_model_revision="v2.0.2",
                   )
 
 
 # example1
 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
                      hotword='杈炬懇闄� 榄旀惌',
-                     # preset_spk_num=2,
+                     # return_raw_text=True,     # return raw text recognition results splited by space of equal length with timestamp
+                     # preset_spk_num=2,         # preset speaker num for speaker cluster model
                      # sentence_timestamp=True,  # return sentence level information when spk_model is not given
                     )
 print(res)
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 78e47cc..e5faa2a 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -379,12 +379,14 @@
                             result[k] = restored_data[j][k]
                         else:
                             result[k] += restored_data[j][k]
-                            
+            
+            return_raw_text = kwargs.get('return_raw_text', False)            
             # step.3 compute punc model
             if self.punc_model is not None:
                 self.punc_kwargs.update(cfg)
                 punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, disable_pbar=True, **cfg)
                 raw_text = copy.copy(result["text"])
+                if return_raw_text: result['raw_text'] = raw_text
                 result["text"] = punc_res[0]["text"]
             else:
                 raw_text = None
@@ -403,26 +405,28 @@
                     for res, vadsegment in zip(restored_data, vadsegments):
                         if 'timestamp' not in res:
                             logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
-                                and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
-                                can predict timestamp, and speaker diarization relies on timestamps.")
-                        sentence_list.append({"start": vadsegment[0],\
-                                                "end": vadsegment[1],
-                                                "sentence": res['text'],
-                                                "timestamp": res['timestamp']})
+                                           and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
+                                           can predict timestamp, and speaker diarization relies on timestamps.")
+                        sentence_list.append({"start": vadsegment[0],
+                                              "end": vadsegment[1],
+                                              "sentence": res['text'],
+                                              "timestamp": res['timestamp']})
                 elif self.spk_mode == 'punc_segment':
                     if 'timestamp' not in result:
                         logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
-                            and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
-                            can predict timestamp, and speaker diarization relies on timestamps.")
-                    sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
-                                                        result['timestamp'], \
-                                                        raw_text)
+                                       and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
+                                       can predict timestamp, and speaker diarization relies on timestamps.")
+                    sentence_list = timestamp_sentence(punc_res[0]['punc_array'],
+                                                       result['timestamp'],
+                                                       raw_text,
+                                                       return_raw_text=return_raw_text)
                 distribute_spk(sentence_list, sv_output)
                 result['sentence_info'] = sentence_list
             elif kwargs.get("sentence_timestamp", False):
-                sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
-                                                        result['timestamp'], \
-                                                        raw_text)
+                sentence_list = timestamp_sentence(punc_res[0]['punc_array'],
+                                                   result['timestamp'],
+                                                   raw_text,
+                                                   return_raw_text=return_raw_text)
                 result['sentence_info'] = sentence_list
             if "spk_embedding" in result: del result['spk_embedding']
                     
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index cf31cdb..729b8f5 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -537,7 +537,6 @@
                     
                     result_i = {"key": key[i], "text": text_postprocessed}
 
-                    
                     if ibest_writer is not None:
                         ibest_writer["token"][key[i]] = " ".join(token)
                         # ibest_writer["text"][key[i]] = text
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 63f179a..32f0f84 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -98,7 +98,7 @@
     return res_txt, res
 
 
-def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed):
+def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False):
     punc_list = ['锛�', '銆�', '锛�', '銆�']
     res = []
     if text_postprocessed is None:
@@ -142,15 +142,24 @@
 
         punc_id = int(punc_id) if punc_id is not None else 1
         sentence_end = timestamp[1] if timestamp is not None else sentence_end
-
+        sentence_text_seg = sentence_text_seg[:-1] if sentence_text_seg[-1] == ' ' else sentence_text_seg
         if punc_id > 1:
             sentence_text += punc_list[punc_id - 2]
-            res.append({
-                'text': sentence_text,
-                "start": sentence_start,
-                "end": sentence_end,
-                "timestamp": ts_list
-            })
+            if return_raw_text:
+                res.append({
+                    'text': sentence_text,
+                    "start": sentence_start,
+                    "end": sentence_end,
+                    "timestamp": ts_list,
+                    'raw_text': sentence_text_seg,
+                })
+            else:
+                res.append({
+                    'text': sentence_text,
+                    "start": sentence_start,
+                    "end": sentence_end,
+                    "timestamp": ts_list,
+                })
             sentence_text = ''
             sentence_text_seg = ''
             ts_list = []

--
Gitblit v1.9.1