From 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 18:03:06 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/main' into dev_aky

---
 funasr/bin/asr_inference_paraformer_vad.py |   21 +++++++++++++++++++--
 1 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer_vad.py b/funasr/bin/asr_inference_paraformer_vad.py
index dbb2719..1548f9f 100644
--- a/funasr/bin/asr_inference_paraformer_vad.py
+++ b/funasr/bin/asr_inference_paraformer_vad.py
@@ -38,7 +38,6 @@
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_lfr6
 from funasr.bin.punctuation_infer import Text2Punc
 from funasr.bin.asr_inference_paraformer_vad_punc import Speech2Text
 from funasr.bin.asr_inference_paraformer_vad_punc import Speech2VadSegment
@@ -168,6 +167,11 @@
         level=log_level,
         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
     )
+
+    if param_dict is not None:
+        hotword_list_or_file = param_dict.get('hotword')
+    else:
+        hotword_list_or_file = None
     
     if ngpu >= 1 and torch.cuda.is_available():
         device = "cuda"
@@ -207,6 +211,7 @@
         ngram_weight=ngram_weight,
         penalty=penalty,
         nbest=nbest,
+        hotword_list_or_file=hotword_list_or_file,
     )
     speech2text = Speech2Text(**speech2text_kwargs)
     text2punc = None
@@ -223,7 +228,19 @@
                  output_dir_v2: Optional[str] = None,
                  fs: dict = None,
                  param_dict: dict = None,
+                 **kwargs,
                  ):
+
+        hotword_list_or_file = None
+        if param_dict is not None:
+            hotword_list_or_file = param_dict.get('hotword')
+
+        if 'hotword' in kwargs:
+            hotword_list_or_file = kwargs['hotword']
+
+        if speech2text.hotword_list is None:
+            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
         # 3. Build data-iterator
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, torch.Tensor):
@@ -321,7 +338,7 @@
                     ibest_writer["token"][key] = " ".join(token)
                     ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                     ibest_writer["vad"][key] = "{}".format(vadsegments)
-                    ibest_writer["text"][key] = text_postprocessed
+                    ibest_writer["text"][key] = " ".join(word_lists)
                     ibest_writer["text_with_punc"][key] = text_postprocessed_punc
                     if time_stamp_postprocessed is not None:
                         ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)

--
Gitblit v1.9.1