From a4d87a7fff1eccd192b9a3637ecb185b009c7977 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 13 三月 2023 12:44:55 +0800
Subject: [PATCH] Merge pull request #217 from alibaba-damo-academy/dev_wjm

---
 tests/test_sv_inference_pipeline.py                                                          |   47 +++++++++++
 tests/test_punctuation_pipeline.py                                                           |   43 ++++++++++
 tests/test_vad_inference_pipeline.py                                                         |   35 ++++++++
 tests/test_asr_inference_pipeline.py                                                         |    4 
 tests/test_lm_pipeline.py                                                                    |   25 ++++++
 egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py |    4 
 tests/test_asr_vad_punc_inference_pipeline.py                                                |   32 ++++++++
 7 files changed, 186 insertions(+), 4 deletions(-)

diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py
index 02859c2..540e3cf 100644
--- a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py
+++ b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py
@@ -5,7 +5,7 @@
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
 
-inference_pipline = pipeline(
+inference_pipeline = pipeline(
     task=Tasks.punctuation,
     model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
     model_revision="v1.0.0",
@@ -17,7 +17,7 @@
 cache_out = []
 rec_result_all="outputs:"
 for vad in vads:
-    rec_result = inference_pipline(text_in=vad, cache=cache_out)
+    rec_result = inference_pipeline(text_in=vad, cache=cache_out)
     #print(rec_result)
     cache_out = rec_result['cache']
     rec_result_all += rec_result['text']
diff --git a/tests/test_asr_inference_pipeline.py b/tests/test_asr_inference_pipeline.py
index b3c5a24..70dbe89 100644
--- a/tests/test_asr_inference_pipeline.py
+++ b/tests/test_asr_inference_pipeline.py
@@ -451,8 +451,8 @@
 
     def test_uniasr_2pass_zhcn_16k_common_vocab8358_offline(self):
         inference_pipeline = pipeline(
-            task=Tasks.auto_speech_recognition,
-            model='damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline')
+            task=Tasks.auto_speech_recognition,
+            model='damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline')
         rec_result = inference_pipeline(
             audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav',
             param_dict={"decoding_model": "offline"})
diff --git a/tests/test_asr_vad_punc_inference_pipeline.py b/tests/test_asr_vad_punc_inference_pipeline.py
new file mode 100644
index 0000000..628b256
--- /dev/null
+++ b/tests/test_asr_vad_punc_inference_pipeline.py
@@ -0,0 +1,32 @@
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+class TestParaformerInferencePipelines(unittest.TestCase):
+    def test_funasr_path(self):
+        import os
+        import funasr  # imported only so its install location can be logged
+        logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))
+
+    def test_inference_pipeline(self):
+        """Run the Paraformer ASR pipeline with VAD and punctuation models on a sample WAV URL and log the result."""
+        asr_pipeline = pipeline(
+            task=Tasks.auto_speech_recognition,
+            model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
+            model_revision="v1.2.1",
+            vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
+            vad_model_revision="v1.1.8",
+            punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
+            punc_model_revision="v1.1.6",
+            ngpu=1)
+        audio_url = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
+        asr_result = asr_pipeline(audio_in=audio_url)
+        logger.info("asr_vad_punc inference result: {0}".format(asr_result))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_lm_pipeline.py b/tests/test_lm_pipeline.py
new file mode 100644
index 0000000..3a5ec57
--- /dev/null
+++ b/tests/test_lm_pipeline.py
@@ -0,0 +1,25 @@
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+class TestTransformerInferencePipelines(unittest.TestCase):
+    def test_funasr_path(self):
+        import funasr  # imported only so its install location can be logged
+        import os
+        logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))
+
+    def test_inference_pipeline(self):
+        """Feed a space-separated English/Chinese sentence to the language-model pipeline and log the result."""
+        inference_pipeline = pipeline(
+            task=Tasks.language_score_prediction,
+            model='damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch')
+        rec_result = inference_pipeline(text_in="hello 大 家 好 呀")
+        logger.info("lm inference result: {0}".format(rec_result))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_punctuation_pipeline.py b/tests/test_punctuation_pipeline.py
new file mode 100644
index 0000000..52be9bb
--- /dev/null
+++ b/tests/test_punctuation_pipeline.py
@@ -0,0 +1,43 @@
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+class TestTransformerInferencePipelines(unittest.TestCase):
+    def test_funasr_path(self):
+        import funasr  # imported only so its install location can be logged
+        import os
+        logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))
+
+    def test_inference_pipeline(self):
+        """Punctuate the example text file (path is relative to the working directory) and log the result."""
+        inference_pipeline = pipeline(
+            task=Tasks.punctuation,
+            model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
+            model_revision="v1.1.7")
+        inputs = "./egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt"
+        rec_result = inference_pipeline(text_in=inputs)
+        logger.info("punctuation inference result: {0}".format(rec_result))
+
+    def test_vadrealtime_inference_pipeline(self):
+        """Punctuate the '|'-separated segments one at a time, passing each call's returned cache to the next call."""
+        inference_pipeline = pipeline(
+            task=Tasks.punctuation,
+            model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
+            model_revision="v1.0.0")
+        inputs = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
+        vads = inputs.split("|")
+        cache_out = []
+        rec_result_all = "outputs:"
+        for vad in vads:
+            rec_result = inference_pipeline(text_in=vad, cache=cache_out)
+            cache_out = rec_result['cache']
+            rec_result_all += rec_result['text']
+        logger.info("punctuation inference result: {0}".format(rec_result_all))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_sv_inference_pipeline.py b/tests/test_sv_inference_pipeline.py
new file mode 100644
index 0000000..265f839
--- /dev/null
+++ b/tests/test_sv_inference_pipeline.py
@@ -0,0 +1,47 @@
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+class TestXVectorInferencePipelines(unittest.TestCase):
+    def test_funasr_path(self):
+        import funasr  # imported only so its install location can be logged
+        import os
+        logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))
+
+    def test_inference_pipeline(self):
+        import numpy as np  # local import: this module never imports numpy at top level
+        inference_sv_pipeline = pipeline(
+            task=Tasks.speaker_verification,
+            model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch')
+        # Extract the speaker embedding of each utterance.
+        rec_result = inference_sv_pipeline(
+            audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav')
+        enroll = rec_result["spk_embedding"]
+
+        rec_result = inference_sv_pipeline(
+            audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav')
+        same = rec_result["spk_embedding"]
+
+        rec_result = inference_sv_pipeline(
+            audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav')
+        different = rec_result["spk_embedding"]
+
+        # Cosine similarity for the same speaker, rescaled so sv_threshold maps to 0 and 1.0 maps to 100.
+        sv_threshold = 0.9465
+        same_cos = np.sum(enroll * same) / (np.linalg.norm(enroll) * np.linalg.norm(same))
+        same_cos = max(same_cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
+        logger.info("Similarity: {}".format(same_cos))
+
+        # Cosine similarity for a different speaker, rescaled the same way.
+        diff_cos = np.sum(enroll * different) / (np.linalg.norm(enroll) * np.linalg.norm(different))
+        diff_cos = max(diff_cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
+        logger.info("Similarity: {}".format(diff_cos))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_vad_inference_pipeline.py b/tests/test_vad_inference_pipeline.py
new file mode 100644
index 0000000..d22f461
--- /dev/null
+++ b/tests/test_vad_inference_pipeline.py
@@ -0,0 +1,35 @@
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+class TestFSMNInferencePipelines(unittest.TestCase):
+    def test_funasr_path(self):
+        import os
+        import funasr  # imported only so its install location can be logged
+        logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))
+
+    def test_8k(self):
+        """Run the 8k FSMN voice-activity-detection model on a sample WAV URL and log the result."""
+        vad_pipeline = pipeline(
+            task=Tasks.voice_activity_detection,
+            model="damo/speech_fsmn_vad_zh-cn-8k-common")
+        audio_url = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example_8k.wav'
+        vad_result = vad_pipeline(audio_in=audio_url)
+        logger.info("vad inference result: {0}".format(vad_result))
+
+    def test_16k(self):
+        """Run the 16k FSMN voice-activity-detection model on a sample WAV URL and log the result."""
+        vad_pipeline = pipeline(
+            task=Tasks.voice_activity_detection,
+            model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch")
+        audio_url = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav'
+        vad_result = vad_pipeline(audio_in=audio_url)
+        logger.info("vad inference result: {0}".format(vad_result))
+
+
+if __name__ == '__main__':
+    unittest.main()

--
Gitblit v1.9.1