From f11f53ec5c7bb09ebe315b2a39ad17ddab339759 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期五, 10 三月 2023 18:48:16 +0800
Subject: [PATCH] update unittest
---
tests/test_punctuation_pipeline.py | 43 +++++++++++++++++++++++++++++++++++++++++++
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py | 4 ++--
2 files changed, 45 insertions(+), 2 deletions(-)
diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py
index 02859c2..540e3cf 100644
--- a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py
+++ b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py
@@ -5,7 +5,7 @@
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
-inference_pipline = pipeline(
+inference_pipeline = pipeline(
task=Tasks.punctuation,
model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
model_revision="v1.0.0",
@@ -17,7 +17,7 @@
cache_out = []
rec_result_all="outputs:"
for vad in vads:
- rec_result = inference_pipline(text_in=vad, cache=cache_out)
+ rec_result = inference_pipeline(text_in=vad, cache=cache_out)
#print(rec_result)
cache_out = rec_result['cache']
rec_result_all += rec_result['text']
diff --git a/tests/test_punctuation_pipeline.py b/tests/test_punctuation_pipeline.py
new file mode 100644
index 0000000..4610f01
--- /dev/null
+++ b/tests/test_punctuation_pipeline.py
@@ -0,0 +1,43 @@
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+class TestTransformerInferencePipelines(unittest.TestCase):
+ def test_funasr_path(self):
+ import funasr
+ import os
+ logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))
+
+ def test_inference_pipeline(self):
+ inference_pipeline = pipeline(
+ task=Tasks.punctuation,
+ model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
+ model_revision="v1.1.7",
+ )
+ inputs = "./egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt"
+ rec_result = inference_pipeline(text_in=inputs)
+ logger.info("asr inference result: {0}".format(rec_result))
+
+ def test_vadrealtime_inference_pipeline(self):
+ inference_pipeline = pipeline(
+ task=Tasks.punctuation,
+ model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
+ model_revision="v1.0.0",
+ )
+ inputs = "璺ㄥ娌虫祦鏄吇鑲叉部宀竱浜烘皯鐨勭敓鍛戒箣婧愰暱鏈熶互鏉ヤ负甯姪涓嬫父鍦板尯闃茬伨鍑忕伨涓柟鎶�鏈汉鍛榺鍦ㄤ笂娓稿湴鍖烘瀬涓烘伓鍔g殑鑷劧鏉′欢涓嬪厠鏈嶅法澶у洶闅剧敋鑷冲啋鐫�鐢熷懡鍗遍櫓|鍚戝嵃鏂规彁渚涙睕鏈熸按鏂囪祫鏂欏鐞嗙揣鎬ヤ簨浠朵腑鏂归噸瑙嗗嵃鏂瑰湪璺ㄥ娌虫祦闂涓婄殑鍏冲垏|鎰挎剰杩涗竴姝ュ畬鍠勫弻鏂硅仈鍚堝伐浣滄満鍒秥鍑℃槸|涓柟鑳藉仛鐨勬垜浠瑋閮戒細鍘诲仛鑰屼笖浼氬仛寰楁洿濂芥垜璇峰嵃搴︽湅鍙嬩滑鏀惧績涓浗鍦ㄤ笂娓哥殑|浠讳綍寮�鍙戝埄鐢ㄩ兘浼氱粡杩囩瀛瑙勫垝鍜岃璇佸吋椤句笂涓嬫父鐨勫埄鐩�"
+ vads = inputs.split("|")
+ cache_out = []
+ rec_result_all = "outputs:"
+ for vad in vads:
+ rec_result = inference_pipeline(text_in=vad, cache=cache_out)
+ cache_out = rec_result['cache']
+ rec_result_all += rec_result['text']
+ logger.info("asr inference result: {0}".format(rec_result_all))
+
+
+if __name__ == '__main__':
+ unittest.main()
--
Gitblit v1.9.1