From 5563b28a74d2058c1d7c0c79f816b9cc5eb5295a Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期四, 13 四月 2023 10:20:17 +0800
Subject: [PATCH] Merge pull request #345 from alibaba-damo-academy/dev_tmp
---
tests/test_punctuation_pipeline.py | 8 +++-----
tests/test_asr_inference_pipeline.py | 2 ++
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/tests/test_asr_inference_pipeline.py b/tests/test_asr_inference_pipeline.py
index b3c5a24..2f2f11d 100644
--- a/tests/test_asr_inference_pipeline.py
+++ b/tests/test_asr_inference_pipeline.py
@@ -43,6 +43,7 @@
rec_result = inference_pipeline(
audio_in='https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "姣忎竴澶╅兘瑕佸揩涔愬枖"
def test_paraformer(self):
inference_pipeline = pipeline(
@@ -51,6 +52,7 @@
rec_result = inference_pipeline(
audio_in='https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "姣忎竴澶╅兘瑕佸揩涔愬枖"
class TestMfccaInferencePipelines(unittest.TestCase):
diff --git a/tests/test_punctuation_pipeline.py b/tests/test_punctuation_pipeline.py
index 52be9bb..e582042 100644
--- a/tests/test_punctuation_pipeline.py
+++ b/tests/test_punctuation_pipeline.py
@@ -26,16 +26,14 @@
inference_pipeline = pipeline(
task=Tasks.punctuation,
model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
- model_revision="v1.0.0",
)
inputs = "璺ㄥ娌虫祦鏄吇鑲叉部宀竱浜烘皯鐨勭敓鍛戒箣婧愰暱鏈熶互鏉ヤ负甯姪涓嬫父鍦板尯闃茬伨鍑忕伨涓柟鎶�鏈汉鍛榺鍦ㄤ笂娓稿湴鍖烘瀬涓烘伓鍔g殑鑷劧鏉′欢涓嬪厠鏈嶅法澶у洶闅剧敋鑷冲啋鐫�鐢熷懡鍗遍櫓|鍚戝嵃鏂规彁渚涙睕鏈熸按鏂囪祫鏂欏鐞嗙揣鎬ヤ簨浠朵腑鏂归噸瑙嗗嵃鏂瑰湪璺ㄥ娌虫祦闂涓婄殑鍏冲垏|鎰挎剰杩涗竴姝ュ畬鍠勫弻鏂硅仈鍚堝伐浣滄満鍒秥鍑℃槸|涓柟鑳藉仛鐨勬垜浠瑋閮戒細鍘诲仛鑰屼笖浼氬仛寰楁洿濂芥垜璇峰嵃搴︽湅鍙嬩滑鏀惧績涓浗鍦ㄤ笂娓哥殑|浠讳綍寮�鍙戝埄鐢ㄩ兘浼氱粡杩囩瀛瑙勫垝鍜岃璇佸吋椤句笂涓嬫父鐨勫埄鐩�"
vads = inputs.split("|")
- cache_out = []
rec_result_all = "outputs:"
+ param_dict = {"cache": []}
for vad in vads:
- rec_result = inference_pipeline(text_in=vad, cache=cache_out)
- cache_out = rec_result['cache']
- rec_result_all += rec_result['text']
+ rec_result = inference_pipeline(text_in=vad, param_dict=param_dict)
+ rec_result_all += rec_result["text"]
logger.info("punctuation inference result: {0}".format(rec_result_all))
--
Gitblit v1.9.1