From 602fe75a1f0a8d64ccb6fc4d69ad510872fdfd13 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 17 三月 2023 20:30:40 +0800
Subject: [PATCH] rtf benchmark
---
funasr/bin/asr_inference_paraformer_vad.py | 38 ++++++++++++++++++++++++++++++++------
1 files changed, 32 insertions(+), 6 deletions(-)
diff --git a/funasr/bin/asr_inference_paraformer_vad.py b/funasr/bin/asr_inference_paraformer_vad.py
index 2cd28cc..a0dc0aa 100644
--- a/funasr/bin/asr_inference_paraformer_vad.py
+++ b/funasr/bin/asr_inference_paraformer_vad.py
@@ -38,7 +38,6 @@
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_lfr6
from funasr.bin.punctuation_infer import Text2Punc
from funasr.bin.asr_inference_paraformer_vad_punc import Speech2Text
from funasr.bin.asr_inference_paraformer_vad_punc import Speech2VadSegment
@@ -168,6 +167,11 @@
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
+
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+ else:
+ hotword_list_or_file = None
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
@@ -207,6 +211,7 @@
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
+ hotword_list_or_file=hotword_list_or_file,
)
speech2text = Speech2Text(**speech2text_kwargs)
text2punc = None
@@ -223,7 +228,19 @@
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
+ **kwargs,
):
+
+ hotword_list_or_file = None
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+
+ if 'hotword' in kwargs:
+ hotword_list_or_file = kwargs['hotword']
+
+ if speech2text.hotword_list is None:
+ speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
@@ -241,6 +258,11 @@
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
+
+ if param_dict is not None:
+ use_timestamp = param_dict.get('use_timestamp', True)
+ else:
+ use_timestamp = True
finish_count = 0
file_count = 1
@@ -284,8 +306,10 @@
text, token, token_int = result[0], result[1], result[2]
time_stamp = None if len(result) < 4 else result[3]
-
- postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+ if use_timestamp and time_stamp is not None:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+ else:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token)
text_postprocessed = ""
time_stamp_postprocessed = ""
text_postprocessed_punc = postprocessed_result
@@ -293,9 +317,11 @@
text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
postprocessed_result[1], \
postprocessed_result[2]
- text_postprocessed_punc = ""
- if len(word_lists) > 0 and text2punc is not None:
- text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+ else:
+ text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+ text_postprocessed_punc = text_postprocessed
+ if len(word_lists) > 0 and text2punc is not None:
+ text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
item = {'key': key, 'value': text_postprocessed_punc}
--
Gitblit v1.9.1