From 0856ea2ebdcb976db6e786de5cd79fae3d35cd4c Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 20 二月 2023 18:18:35 +0800
Subject: [PATCH] Merge pull request #136 from alibaba-damo-academy/dev_cmz
---
funasr/bin/asr_inference_uniasr_vad.py | 12 ++++++++++++
1 files changed, 12 insertions(+), 0 deletions(-)
diff --git a/funasr/bin/asr_inference_uniasr_vad.py b/funasr/bin/asr_inference_uniasr_vad.py
index 0a5824c..de32dcf 100644
--- a/funasr/bin/asr_inference_uniasr_vad.py
+++ b/funasr/bin/asr_inference_uniasr_vad.py
@@ -439,6 +439,18 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ if param_dict is not None and "decoding_model" in param_dict:
+ if param_dict["decoding_model"] == "fast":
+ speech2text.decoding_ind = 0
+ speech2text.decoding_mode = "model1"
+ elif param_dict["decoding_model"] == "normal":
+ speech2text.decoding_ind = 0
+ speech2text.decoding_mode = "model2"
+ elif param_dict["decoding_model"] == "offline":
+ speech2text.decoding_ind = 1
+ speech2text.decoding_mode = "model2"
+ else:
+ raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
--
Gitblit v1.9.1