From 602fe75a1f0a8d64ccb6fc4d69ad510872fdfd13 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 17 三月 2023 20:30:40 +0800
Subject: [PATCH] rtf benchmark

---
 funasr/bin/asr_inference_paraformer_vad.py |   38 ++++++++++++++++++++++++++++++++------
 1 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer_vad.py b/funasr/bin/asr_inference_paraformer_vad.py
index 2cd28cc..a0dc0aa 100644
--- a/funasr/bin/asr_inference_paraformer_vad.py
+++ b/funasr/bin/asr_inference_paraformer_vad.py
@@ -38,7 +38,6 @@
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_lfr6
 from funasr.bin.punctuation_infer import Text2Punc
 from funasr.bin.asr_inference_paraformer_vad_punc import Speech2Text
 from funasr.bin.asr_inference_paraformer_vad_punc import Speech2VadSegment
@@ -168,6 +167,11 @@
         level=log_level,
         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
     )
+
+    if param_dict is not None:
+        hotword_list_or_file = param_dict.get('hotword')
+    else:
+        hotword_list_or_file = None
     
     if ngpu >= 1 and torch.cuda.is_available():
         device = "cuda"
@@ -207,6 +211,7 @@
         ngram_weight=ngram_weight,
         penalty=penalty,
         nbest=nbest,
+        hotword_list_or_file=hotword_list_or_file,
     )
     speech2text = Speech2Text(**speech2text_kwargs)
     text2punc = None
@@ -223,7 +228,19 @@
                  output_dir_v2: Optional[str] = None,
                  fs: dict = None,
                  param_dict: dict = None,
+                 **kwargs,
                  ):
+
+        hotword_list_or_file = None
+        if param_dict is not None:
+            hotword_list_or_file = param_dict.get('hotword')
+
+        if 'hotword' in kwargs:
+            hotword_list_or_file = kwargs['hotword']
+
+        if speech2text.hotword_list is None:
+            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
         # 3. Build data-iterator
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, torch.Tensor):
@@ -241,6 +258,11 @@
             allow_variable_data_keys=allow_variable_data_keys,
             inference=True,
         )
+
+        if param_dict is not None:
+            use_timestamp = param_dict.get('use_timestamp', True)
+        else:
+            use_timestamp = True
         
         finish_count = 0
         file_count = 1
@@ -284,8 +306,10 @@
                 text, token, token_int = result[0], result[1], result[2]
                 time_stamp = None if len(result) < 4 else result[3]
                
-                
-                postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+                if use_timestamp and time_stamp is not None:
+                    postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+                else:
+                    postprocessed_result = postprocess_utils.sentence_postprocess(token)
                 text_postprocessed = ""
                 time_stamp_postprocessed = ""
                 text_postprocessed_punc = postprocessed_result
@@ -293,9 +317,11 @@
                     text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
                                                                                postprocessed_result[1], \
                                                                                postprocessed_result[2]
-                    text_postprocessed_punc = ""
-                    if len(word_lists) > 0 and text2punc is not None:
-                        text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+                else:
+                    text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+                text_postprocessed_punc = text_postprocessed
+                if len(word_lists) > 0 and text2punc is not None:
+                    text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
 
                 
                 item = {'key': key, 'value': text_postprocessed_punc}

--
Gitblit v1.9.1