From 440e3956fc9d507ea66a2e72f3fe8d27fb77099c Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 13 Feb 2023 11:49:00 +0800
Subject: [PATCH] Merge pull request #99 from alibaba-damo-academy/dev_lzr
---
funasr/bin/asr_inference_paraformer.py | 45 ++++++++++++++++++----
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md | 19 +++++++++
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.py | 21 ++++++++++
3 files changed, 76 insertions(+), 9 deletions(-)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md
new file mode 100644
index 0000000..49c0aeb
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/README.md
@@ -0,0 +1,19 @@
+# ModelScope Model
+
+## How to infer using a pretrained Paraformer-large Model
+
+### Inference
+
+You can use the pretrained model for inference directly.
+
+- Setting parameters in `infer.py`
+ - <strong>audio_in:</strong> # Support wav, url, bytes, and parsed audio format.
+ - <strong>output_dir:</strong> # If the input format is wav.scp, it needs to be set.
+ - <strong>batch_size:</strong> # Set batch size in inference.
+ - <strong>param_dict:</strong> # Set the hotword list in inference.
+
+- Then you can run the pipeline to infer with:
+```shell
+python infer.py
+```
+
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.py
new file mode 100644
index 0000000..78fb8f1
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.py
@@ -0,0 +1,21 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+
+if __name__ == '__main__':
+ param_dict = dict()
+ param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt"
+
+    audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_hotword.wav"
+ output_dir = None
+ batch_size = 1
+
+ inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
+ output_dir=output_dir,
+ batch_size=batch_size,
+ param_dict=param_dict)
+
+ rec_result = inference_pipeline(audio_in=audio_in)
+ print(rec_result)
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 6c5acfc..be35e78 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -6,6 +6,8 @@
import copy
import os
import codecs
+import tempfile
+import requests
from pathlib import Path
from typing import Optional
from typing import Sequence
@@ -175,10 +177,24 @@
self.converter = converter
self.tokenizer = tokenizer
- # 6. [Optional] Build hotword list from file or str
+ # 6. [Optional] Build hotword list from str, local file or url
+ # for None
if hotword_list_or_file is None:
self.hotword_list = None
+ # for text str input
+ elif not os.path.exists(hotword_list_or_file) and not hotword_list_or_file.startswith('http'):
+ logging.info("Attempting to parse hotwords as str...")
+ self.hotword_list = []
+ hotword_str_list = []
+ for hw in hotword_list_or_file.strip().split():
+ hotword_str_list.append(hw)
+ self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ self.hotword_list.append([self.asr_model.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Hotword list: {}.".format(hotword_str_list))
+ # for local txt inputs
elif os.path.exists(hotword_list_or_file):
+ logging.info("Attempting to parse hotwords from local txt...")
self.hotword_list = []
hotword_str_list = []
with codecs.open(hotword_list_or_file, 'r') as fin:
@@ -186,20 +202,31 @@
hw = line.strip()
hotword_str_list.append(hw)
self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- self.hotword_list.append([1])
+ self.hotword_list.append([self.asr_model.sos])
hotword_str_list.append('<s>')
logging.info("Initialized hotword list from file: {}, hotword list: {}."
.format(hotword_list_or_file, hotword_str_list))
+ # for url, download and generate txt
else:
- logging.info("Attempting to parse hotwords as str...")
+ logging.info("Attempting to parse hotwords from url...")
+            work_dir = tempfile.mkdtemp()
+            if not os.path.exists(work_dir):
+                os.makedirs(work_dir)
+ text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+            local_file = requests.get(hotword_list_or_file)
+            with open(text_file_path, "wb") as fout: fout.write(local_file.content)
+ hotword_list_or_file = text_file_path
self.hotword_list = []
hotword_str_list = []
- for hw in hotword_list_or_file.strip().split():
- hotword_str_list.append(hw)
- self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- self.hotword_list.append([1])
- hotword_str_list.append('<s>')
- logging.info("Hotword list: {}.".format(hotword_str_list))
+ with codecs.open(hotword_list_or_file, 'r') as fin:
+ for line in fin.readlines():
+ hw = line.strip()
+ hotword_str_list.append(hw)
+ self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ self.hotword_list.append([self.asr_model.sos])
+ hotword_str_list.append('<s>')
+ logging.info("Initialized hotword list from file: {}, hotword list: {}."
+ .format(hotword_list_or_file, hotword_str_list))
is_use_lm = lm_weight != 0.0 and lm_file is not None
--
Gitblit v1.9.1