From 5ddad6db687ed45bc0b38cfc802ddfc3ab8c7f68 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 17 一月 2023 17:34:38 +0800
Subject: [PATCH] modelscope paraformer large long input

---
 funasr/bin/asr_inference_paraformer_vad_punc.py |   55 +++++++++++++++++++++++++++++++++++++------------------
 1 files changed, 37 insertions(+), 18 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index 265e054..2ee77a8 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -666,9 +666,10 @@
     vad_infer_config: Optional[str] = None,
     vad_model_file: Optional[str] = None,
     vad_cmvn_file: Optional[str] = None,
-    time_stamp_writer: bool = False,
+    time_stamp_writer: bool = True,
     punc_infer_config: Optional[str] = None,
     punc_model_file: Optional[str] = None,
+    outputs_dict: Optional[bool] = True,
     **kwargs,
 ):
     assert check_argument_types()
@@ -725,6 +726,11 @@
     speech2text = Speech2Text(**speech2text_kwargs)
     
     text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
+
+    if output_dir is not None:
+        writer = DatadirWriter(output_dir)
+        ibest_writer = writer[f"1best_recog"]
+        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
     
     def _forward(data_path_and_name_and_type,
                  raw_inputs: Union[np.ndarray, torch.Tensor] = None,
@@ -756,6 +762,9 @@
         output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
         if output_path is not None:
             writer = DatadirWriter(output_path)
+            ibest_writer = writer[f"1best_recog"]
+            # ibest_writer["punc_dict"][""] = " ".join(punc_infer_config.punc_list)
+            # ibest_writer["token_list"][""] = " ".join(asr_train_config.token_list)
         else:
             writer = None
         
@@ -805,11 +814,10 @@
                 
                 # Create a directory: outdir/{n}best_recog
                 if writer is not None:
-                    ibest_writer = writer[f"1best_recog"]
-                    
                     # Write the result to each file
                     ibest_writer["token"][key] = " ".join(token)
-                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                    ibest_writer["vad"][key] = "{}".format(vadsegments)
                 
                 if text is not None:
                     postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
@@ -826,17 +834,22 @@
                         word_lists = None
                         text_postprocessed_punc_time_stamp = None
                         punc_id_list = None
-                    
+
                     item = {'key': key, 'value': text_postprocessed_punc_time_stamp, 'text': text_postprocessed,
-                            'time_stamp': time_stamp_postprocessed, 'punc': punc_id_list}
+                            'time_stamp': time_stamp_postprocessed, 'punc': punc_id_list, 'token': token}
+                    if outputs_dict:
+                        item = {'text_punc': text_postprocessed_punc, 'text': text_postprocessed,
+                                'punc_id': punc_id_list, 'token': token, 'time_stamp': time_stamp_postprocessed}
+                        item = {'key': key, 'value': item}
                     asr_result_list.append(item)
                     finish_count += 1
                     # asr_utils.print_progress(finish_count / file_count)
                     if writer is not None:
                         ibest_writer["text"][key] = text_postprocessed
-                        if time_stamp_writer and time_stamp_postprocessed is not None:
-                            ibest_writer["time_stamp"][key] = " ".join(
-                                ["-".join(map(str, ts)) for ts in time_stamp_postprocessed])
+                        ibest_writer["punc_id"][key] = "{}".format(punc_id_list)
+                        ibest_writer["text_with_punc"][key] = text_postprocessed_punc_time_stamp
+                        if time_stamp_postprocessed is not None:
+                            ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
                 
                 logging.info("decoding, utt: {}, predictions: {}, time_stamp: {}".format(key, text_postprocessed_punc,
                                                                                          time_stamp_postprocessed))
@@ -869,7 +882,6 @@
             punc_list[i] = "锛�"
         elif punc_list[i] == "銆�":
             period = i
-    
     preprocessor = CommonPreprocessor(
         train=False,
         token_type="word",
@@ -887,7 +899,8 @@
         cache_sent = []
         mini_sentences = split_to_mini_sentence(words, split_size)
         new_mini_sentence = ""
-        new_mini_sentence_punc = ""
+        new_mini_sentence_punc = []
+        cache_pop_trigger_limit = 200
         for mini_sentence_i in range(len(mini_sentences)):
             mini_sentence = mini_sentences[mini_sentence_i]
             mini_sentence = cache_sent + mini_sentence
@@ -904,24 +917,31 @@
             if indices.size()[0] != 1:
                 punctuations = torch.squeeze(indices)
             assert punctuations.size()[0] == len(mini_sentence)
-            
+
             # Search for the last Period/QuestionMark as cache
             if mini_sentence_i < len(mini_sentences) - 1:
                 sentenceEnd = -1
+                last_comma_index = -1
                 for i in range(len(punctuations) - 2, 1, -1):
                     if punc_list[punctuations[i]] == "銆�" or punc_list[punctuations[i]] == "锛�":
                         sentenceEnd = i
                         break
-                
+                    if last_comma_index < 0 and punc_list[punctuations[i]] == "锛�":
+                        last_comma_index = i
+
+                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                    # The sentence it too long, cut off at a comma.
+                    sentenceEnd = last_comma_index
+                    punctuations[sentenceEnd] = period
                 cache_sent = mini_sentence[sentenceEnd + 1:]
                 mini_sentence = mini_sentence[0:sentenceEnd + 1]
                 punctuations = punctuations[0:sentenceEnd + 1]
-    
+
             # if len(punctuations) == 0:
             #    continue
-            
+
             punctuations_np = punctuations.cpu().numpy()
-            new_mini_sentence_punc += "".join([str(x) for x in punctuations_np])
+            new_mini_sentence_punc += [int(x) for x in punctuations_np]
             words_with_punc = []
             for i in range(len(mini_sentence)):
                 if i > 0:
@@ -931,9 +951,8 @@
                 if punc_list[punctuations[i]] != "_":
                     words_with_punc.append(punc_list[punctuations[i]])
             new_mini_sentence += "".join(words_with_punc)
-            
+
         return new_mini_sentence, new_mini_sentence_punc
-            
     return _forward
 
 def get_parser():

--
Gitblit v1.9.1