From ee1eefff68e25f2e7674616be34518b07d8135c3 Mon Sep 17 00:00:00 2001
From: Lizerui9926 <110582652+Lizerui9926@users.noreply.github.com>
Date: 星期四, 09 十一月 2023 13:09:28 +0800
Subject: [PATCH] Merge pull request #1075 from alibaba-damo-academy/dev_lzr_en
---
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py | 12 ++++++++++--
1 files changed, 10 insertions(+), 2 deletions(-)
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index 7b13654..c4c558e 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -14,7 +14,8 @@
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
OrtInferSession, TokenIDConverter, get_logger,
read_yaml)
-from .utils.postprocess_utils import sentence_postprocess
+from .utils.postprocess_utils import (sentence_postprocess,
+ sentence_postprocess_sentencepiece)
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.utils import pad_list, make_pad_mask
@@ -86,6 +87,10 @@
self.pred_bias = config['model_conf']['predictor_bias']
else:
self.pred_bias = 0
+ if "lang" in config:
+ self.language = config['lang']
+ else:
+ self.language = None
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
@@ -111,7 +116,10 @@
preds = self.decode(am_scores, valid_token_lens)
if us_peaks is None:
for pred in preds:
- pred = sentence_postprocess(pred)
+ if self.language == "en-bpe":
+ pred = sentence_postprocess_sentencepiece(pred)
+ else:
+ pred = sentence_postprocess(pred)
asr_res.append({'preds': pred})
else:
for pred, us_peaks_ in zip(preds, us_peaks):
--
Gitblit v1.9.1