From db12d1c1d8ddbce39fd702dbc9ec4bcfbcb2003a Mon Sep 17 00:00:00 2001
From: Lizerui9926 <110582652+Lizerui9926@users.noreply.github.com>
Date: 星期四, 16 三月 2023 10:50:10 +0800
Subject: [PATCH] Merge pull request #239 from alibaba-damo-academy/dev_wjm
---
funasr/bin/asr_inference_uniasr.py | 16 +++++++++++++++-
1 files changed, 15 insertions(+), 1 deletions(-)
diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index cfec9a0..8b31fad 100644
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -397,6 +397,19 @@
device = "cuda"
else:
device = "cpu"
+
+ if param_dict is not None and "decoding_model" in param_dict:
+ if param_dict["decoding_model"] == "fast":
+ decoding_ind = 0
+ decoding_mode = "model1"
+ elif param_dict["decoding_model"] == "normal":
+ decoding_ind = 0
+ decoding_mode = "model2"
+ elif param_dict["decoding_model"] == "offline":
+ decoding_ind = 1
+ decoding_mode = "model2"
+ else:
+ raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
# 1. Set random-seed
set_all_random_seed(seed)
@@ -433,6 +446,7 @@
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
+ **kwargs,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
@@ -492,7 +506,7 @@
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
- text_postprocessed = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
--
Gitblit v1.9.1