From 32905d8cdedd53dad26680b0bd41397aaf0e51ae Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 05 一月 2024 11:52:48 +0800
Subject: [PATCH] funasr1.0
---
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py | 20 +++++++++++++++++---
1 files changed, 17 insertions(+), 3 deletions(-)
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index 71cf434..c4c558e 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -14,7 +14,8 @@
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
OrtInferSession, TokenIDConverter, get_logger,
read_yaml)
-from .utils.postprocess_utils import sentence_postprocess
+from .utils.postprocess_utils import (sentence_postprocess,
+ sentence_postprocess_sentencepiece)
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.utils import pad_list, make_pad_mask
@@ -36,7 +37,6 @@
intra_op_num_threads: int = 4,
cache_dir: str = None
):
-
if not Path(model_dir).exists():
try:
from modelscope.hub.snapshot_download import snapshot_download
@@ -87,6 +87,10 @@
self.pred_bias = config['model_conf']['predictor_bias']
else:
self.pred_bias = 0
+ if "lang" in config:
+ self.language = config['lang']
+ else:
+ self.language = None
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
@@ -112,7 +116,10 @@
preds = self.decode(am_scores, valid_token_lens)
if us_peaks is None:
for pred in preds:
- pred = sentence_postprocess(pred)
+ if self.language == "en-bpe":
+ pred = sentence_postprocess_sentencepiece(pred)
+ else:
+ pred = sentence_postprocess(pred)
asr_res.append({'preds': pred})
else:
for pred, us_peaks_ in zip(preds, us_peaks):
@@ -242,6 +249,13 @@
if not Path(model_dir).exists():
try:
+ from modelscope.hub.snapshot_download import snapshot_download
+ except:
+ raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
+ "\npip3 install -U modelscope\n" \
+ "For the users in China, you could install with the command:\n" \
+ "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+ try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
--
Gitblit v1.9.1