From 92dec9b3317fe1b687b29542348dd624f0e7398b Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 04 Jul 2023 20:01:14 +0800
Subject: [PATCH] Merge pull request #703 from alibaba-damo-academy/dev_wjm
---
funasr/bin/diar_inference_launch.py | 15 +++------------
tests/test_asr_inference_pipeline.py | 16 ++++++++++++----
funasr/bin/asr_inference_launch.py | 5 +----
3 files changed, 16 insertions(+), 20 deletions(-)
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index b64e6f0..8310791 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -1367,10 +1367,7 @@
left_context=left_context,
right_context=right_context,
)
- speech2text = Speech2TextTransducer.from_pretrained(
- model_tag=model_tag,
- **speech2text_kwargs,
- )
+ speech2text = Speech2TextTransducer(**speech2text_kwargs)
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py
index 03c9659..c065137 100755
--- a/funasr/bin/diar_inference_launch.py
+++ b/funasr/bin/diar_inference_launch.py
@@ -92,10 +92,7 @@
embedding_node="resnet1_dense"
)
logging.info("speech2xvector_kwargs: {}".format(speech2xvector_kwargs))
- speech2xvector = Speech2Xvector.from_pretrained(
- model_tag=model_tag,
- **speech2xvector_kwargs,
- )
+ speech2xvector = Speech2Xvector(**speech2xvector_kwargs)
speech2xvector.sv_model.eval()
# 2b. Build speech2diar
@@ -109,10 +106,7 @@
dur_threshold=dur_threshold,
)
logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
- speech2diar = Speech2DiarizationSOND.from_pretrained(
- model_tag=model_tag,
- **speech2diar_kwargs,
- )
+ speech2diar = Speech2DiarizationSOND(**speech2diar_kwargs)
speech2diar.diar_model.eval()
def output_results_str(results: dict, uttid: str):
@@ -257,10 +251,7 @@
dtype=dtype,
)
logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
- speech2diar = Speech2DiarizationEEND.from_pretrained(
- model_tag=model_tag,
- **speech2diar_kwargs,
- )
+ speech2diar = Speech2DiarizationEEND(**speech2diar_kwargs)
speech2diar.diar_model.eval()
def output_results_str(results: dict, uttid: str):
diff --git a/tests/test_asr_inference_pipeline.py b/tests/test_asr_inference_pipeline.py
index 2b21acf..f68f29b 100644
--- a/tests/test_asr_inference_pipeline.py
+++ b/tests/test_asr_inference_pipeline.py
@@ -119,20 +119,28 @@
def test_paraformer_large_online_common(self):
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
- model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online')
+ model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
+ model_revision='v1.0.6',
+ update_model=False,
+ mode="paraformer_fake_streaming"
+ )
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr inference result: {0}".format(rec_result))
- assert rec_result["text"] == "欢迎大 家来 体验达 摩院推 出的 语音识 别模 型"
+ assert rec_result["text"] == "欢迎大家来体验达摩院推出的语音识别模型"
def test_paraformer_online_common(self):
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
- model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online')
+ model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
+ model_revision='v1.0.6',
+ update_model=False,
+ mode="paraformer_fake_streaming"
+ )
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr inference result: {0}".format(rec_result))
- assert rec_result["text"] == "欢迎 大家来 体验达 摩院推 出的 语音识 别模 型"
+ assert rec_result["text"] == "欢迎大家来体验达摩院推出的语音识别模型"
def test_paraformer_tiny_commandword(self):
inference_pipeline = pipeline(
--
Gitblit v1.9.1