From 09ff7d4516128bfe1db8a81ca6de0d89ea55d88c Mon Sep 17 00:00:00 2001
From: 仁迷 <haoneng.lhn@alibaba-inc.com>
Date: Thu, 23 Feb 2023 16:28:05 +0800
Subject: [PATCH] fix uniasr decoding bug
---
funasr/bin/asr_inference_uniasr.py | 25 +++++++++++++------------
funasr/bin/asr_inference_uniasr_vad.py | 25 +++++++++++++------------
2 files changed, 26 insertions(+), 24 deletions(-)
diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index c50bf17..8b31fad 100644
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -398,6 +398,19 @@
else:
device = "cpu"
+ if param_dict is not None and "decoding_model" in param_dict:
+ if param_dict["decoding_model"] == "fast":
+ decoding_ind = 0
+ decoding_mode = "model1"
+ elif param_dict["decoding_model"] == "normal":
+ decoding_ind = 0
+ decoding_mode = "model2"
+ elif param_dict["decoding_model"] == "offline":
+ decoding_ind = 1
+ decoding_mode = "model2"
+ else:
+ raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
+
# 1. Set random-seed
set_all_random_seed(seed)
@@ -440,18 +453,6 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- if param_dict is not None and "decoding_model" in param_dict:
- if param_dict["decoding_model"] == "fast":
- speech2text.decoding_ind = 0
- speech2text.decoding_mode = "model1"
- elif param_dict["decoding_model"] == "normal":
- speech2text.decoding_ind = 0
- speech2text.decoding_mode = "model2"
- elif param_dict["decoding_model"] == "offline":
- speech2text.decoding_ind = 1
- speech2text.decoding_mode = "model2"
- else:
- raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
diff --git a/funasr/bin/asr_inference_uniasr_vad.py b/funasr/bin/asr_inference_uniasr_vad.py
index ac3b4b6..e5815df 100644
--- a/funasr/bin/asr_inference_uniasr_vad.py
+++ b/funasr/bin/asr_inference_uniasr_vad.py
@@ -398,6 +398,19 @@
else:
device = "cpu"
+ if param_dict is not None and "decoding_model" in param_dict:
+ if param_dict["decoding_model"] == "fast":
+ decoding_ind = 0
+ decoding_mode = "model1"
+ elif param_dict["decoding_model"] == "normal":
+ decoding_ind = 0
+ decoding_mode = "model2"
+ elif param_dict["decoding_model"] == "offline":
+ decoding_ind = 1
+ decoding_mode = "model2"
+ else:
+ raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
+
# 1. Set random-seed
set_all_random_seed(seed)
@@ -440,18 +453,6 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- if param_dict is not None and "decoding_model" in param_dict:
- if param_dict["decoding_model"] == "fast":
- speech2text.decoding_ind = 0
- speech2text.decoding_mode = "model1"
- elif param_dict["decoding_model"] == "normal":
- speech2text.decoding_ind = 0
- speech2text.decoding_mode = "model2"
- elif param_dict["decoding_model"] == "offline":
- speech2text.decoding_ind = 1
- speech2text.decoding_mode = "model2"
- else:
- raise NotImplementedError("unsupported decoding model {}".format(param_dict["decoding_model"]))
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
--
Gitblit v1.9.1