游雁
2023-09-13 33d3d2084403fd34b79c835d2f2fe04f6cd8f738
Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
add
5个文件已修改
4个文件已添加
238 ■■■■■ 已修改文件
funasr/bin/asr_inference_launch.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_infer.py 23 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/export_model.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/http/README.md 47 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/http/client.py 34 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/http/requirements.txt 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/http/server.py 107 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_launch.py
@@ -415,7 +415,7 @@
                        ibest_writer["rtf"][key] = rtf_cur
                    if text is not None:
                        if use_timestamp and timestamp is not None:
                        if use_timestamp and timestamp is not None and len(timestamp):
                            postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
                        else:
                            postprocessed_result = postprocess_utils.sentence_postprocess(token)
@@ -427,7 +427,7 @@
                        else:
                            text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
                        item = {'key': key, 'value': text_postprocessed}
                        if timestamp_postprocessed != "":
                        if timestamp_postprocessed != "" or len(timestamp) == 0:
                            item['timestamp'] = timestamp_postprocessed
                        asr_result_list.append(item)
                        finish_count += 1
@@ -692,7 +692,7 @@
            text, token, token_int = result[0], result[1], result[2]
            time_stamp = result[4] if len(result[4]) > 0 else None
            if use_timestamp and time_stamp is not None:
            if use_timestamp and time_stamp is not None and len(time_stamp):
                postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
            else:
                postprocessed_result = postprocess_utils.sentence_postprocess(token)
@@ -717,7 +717,7 @@
            item = {'key': key, 'value': text_postprocessed_punc}
            if text_postprocessed != "":
                item['text_postprocessed'] = text_postprocessed
            if time_stamp_postprocessed != "":
            if time_stamp_postprocessed != "" or len(time_stamp) == 0:
                item['time_stamp'] = time_stamp_postprocessed
            item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
funasr/bin/punc_infer.py
@@ -117,12 +117,25 @@
            new_mini_sentence_punc += [int(x) for x in punctuations_np]
            words_with_punc = []
            for i in range(len(mini_sentence)):
                if (i==0 or self.punc_list[punctuations[i-1]] == "。" or self.punc_list[punctuations[i-1]] == "?") and len(mini_sentence[i][0].encode()) == 1:
                    mini_sentence[i] = mini_sentence[i].capitalize()
                if i == 0:
                    if len(mini_sentence[i][0].encode()) == 1:
                        mini_sentence[i] = " " + mini_sentence[i]
                if i > 0:
                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
                        mini_sentence[i] = " " + mini_sentence[i]
                words_with_punc.append(mini_sentence[i])
                if self.punc_list[punctuations[i]] != "_":
                    words_with_punc.append(self.punc_list[punctuations[i]])
                    punc_res = self.punc_list[punctuations[i]]
                    if len(mini_sentence[i][0].encode()) == 1:
                        if punc_res == ",":
                            punc_res = ","
                        elif punc_res == "。":
                            punc_res = "."
                        elif punc_res == "?":
                            punc_res = "?"
                    words_with_punc.append(punc_res)
            new_mini_sentence += "".join(words_with_punc)
            # Add Period for the end of the sentence
            new_mini_sentence_out = new_mini_sentence
@@ -131,9 +144,15 @@
                if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
                    new_mini_sentence_out = new_mini_sentence[:-1] + "。"
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?":
                elif new_mini_sentence[-1] == ",":
                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==0:
                    new_mini_sentence_out = new_mini_sentence + "。"
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
                    new_mini_sentence_out = new_mini_sentence + "."
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
        return new_mini_sentence_out, new_mini_sentence_punc_out
funasr/export/export_model.py
@@ -254,7 +254,7 @@
            if not os.path.exists(quant_model_path):
                onnx_model = onnx.load(model_path)
                nodes = [n.name for n in onnx_model.graph.node]
                nodes_to_exclude = [m for m in nodes if 'output' in m]
                nodes_to_exclude = [m for m in nodes if 'output' in m or 'bias_encoder' in m  or 'bias_decoder' in m]
                quantize_dynamic(
                    model_input=model_path,
                    model_output=quant_model_path,
funasr/runtime/python/http/README.md
New file
@@ -0,0 +1,47 @@
# HTTP service (Python)
## Server
1. Install requirements
```shell
cd funasr/runtime/python/http
pip install -r requirements.txt
```
2. Start server
```shell
python server.py --port 8000
```
More parameters:
```shell
python server.py \
--host [host ip] \
--port [server port] \
--asr_model [asr model_name] \
--punc_model [punc model_name] \
--ngpu [0 or 1] \
--ncpu [number of cpu cores, default 4] \
--certfile [path of certfile for ssl] \
--keyfile [path of keyfile for ssl] \
--temp_dir [upload file temp dir]
```
## Client
```shell
# get test audio file
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav
python client.py --host=127.0.0.1 --port=8000 --audio_path=asr_example_zh.wav
```
More parameters:
```shell
python client.py \
--host [server ip] \
--port [server port] \
--add_pun [1 to add punctuation to the result, 0 to skip it] \
--audio_path [path of the audio file to recognize]
```
funasr/runtime/python/http/client.py
New file
@@ -0,0 +1,34 @@
"""Command-line client for the FunASR HTTP service: POSTs an audio file to /recognition and prints the reply."""
import argparse
import os

import requests

parser = argparse.ArgumentParser()
parser.add_argument("--host",
                    type=str,
                    default="127.0.0.1",
                    required=False,
                    help="sever ip")
parser.add_argument("--port",
                    type=int,
                    default=8000,
                    required=False,
                    help="server port")
parser.add_argument("--add_pun",
                    type=int,
                    default=1,
                    required=False,
                    help="add pun to result")
parser.add_argument("--audio_path",
                    type=str,
                    default='asr_example_zh.wav',
                    required=False,
                    help="use audio path")
args = parser.parse_args()

url = f'http://{args.host}:{args.port}/recognition'
# Sent as a form field alongside the multipart file upload.
data = {'add_pun': args.add_pun}
# The context manager guarantees the audio file handle is closed after the upload.
# The real base name is sent (instead of a fixed placeholder) because the server
# derives the temporary file's extension from the uploaded file name.
with open(args.audio_path, 'rb') as audio_file:
    files = [('audio', (os.path.basename(args.audio_path), audio_file, 'application/octet-stream'))]
    response = requests.post(url, data=data, files=files)
print(response.text)
funasr/runtime/python/http/requirements.txt
New file
@@ -0,0 +1,6 @@
modelscope>=1.8.4
fastapi>=0.95.1
ffmpeg-python
aiofiles
uvicorn
requests
funasr/runtime/python/http/server.py
New file
@@ -0,0 +1,107 @@
import argparse
import logging
import os
import random
import time
import aiofiles
import ffmpeg
import uvicorn
from fastapi import FastAPI, File, UploadFile, Body
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
# Configure the logger obtained from modelscope's get_logger so that only
# CRITICAL records are emitted.
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)

# Command-line options (all optional; argparse's default is required=False).
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0", help="host ip, localhost, 0.0.0.0")
parser.add_argument("--port", type=int, default=8000, help="server port")
parser.add_argument("--asr_model", type=str,
                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                    help="model from modelscope")
parser.add_argument("--punc_model", type=str,
                    default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
                    help="model from modelscope")
parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
parser.add_argument("--certfile", type=str, default=None, help="certfile for ssl")
parser.add_argument("--keyfile", type=str, default=None, help="keyfile for ssl")
parser.add_argument("--temp_dir", type=str, default="temp_dir/", help="temp dir")
args = parser.parse_args()

# Uploaded audio is staged here before ffmpeg decodes it.
os.makedirs(args.temp_dir, exist_ok=True)

print("model loading")
# Speech-recognition pipeline: always loaded.
inference_pipeline_asr = pipeline(
    task=Tasks.auto_speech_recognition,
    model=args.asr_model,
    model_revision=None,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
)
print('loaded asr models.')

# Punctuation pipeline: optional; an empty --punc_model disables it and leaves None.
if not args.punc_model:
    inference_pipeline_punc = None
else:
    inference_pipeline_punc = pipeline(
        task=Tasks.punctuation,
        model=args.punc_model,
        model_revision="v1.0.2",
        ngpu=args.ngpu,
        ncpu=args.ncpu,
    )
    print('loaded pun models.')

app = FastAPI(title="FunASR")
def _safe_suffix(filename):
    """Return a sanitized file extension (with leading dot) taken from an uploaded file name.

    ``filename`` is supplied by the client, so only the extension of its final path
    component is used, reduced to ASCII letters and digits. Returns "" when the name
    is missing or has no usable extension; the extension is only a hint for ffmpeg.
    """
    ext = os.path.splitext(os.path.basename(filename or ""))[1].lstrip(".")
    ext = "".join(ch for ch in ext if ch.isascii() and ch.isalnum())
    return f".{ext}" if ext else ""


@app.post("/recognition")
async def api_recognition(audio: UploadFile = File(..., description="audio file"),
                          add_pun: int = Body(1, description="add punctuation", embed=True)):
    """Transcribe an uploaded audio file, optionally adding punctuation.

    The upload is staged in ``args.temp_dir``, decoded by ffmpeg to 16 kHz mono
    little-endian 16-bit PCM, and fed to the ASR pipeline. The staged file is always
    removed afterwards. Punctuation is applied only when ``add_pun`` is truthy, a
    punctuation pipeline was loaded, and the transcript is non-empty.

    Args:
        audio: the uploaded audio file.
        add_pun: non-zero to run the punctuation model on the transcript.

    Returns:
        dict: ``{"results": <transcript text>, "code": 0}``.
    """
    audio_path = os.path.join(
        args.temp_dir,
        f'{int(time.time() * 1000)}_{random.randint(100, 999)}{_safe_suffix(audio.filename)}')
    try:
        async with aiofiles.open(audio_path, 'wb') as out_file:
            content = await audio.read()
            await out_file.write(content)
        audio_bytes, _ = (
            ffmpeg.input(audio_path, threads=0)
            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
        )
    finally:
        # The decoded PCM is now in memory; delete the staged upload so temp_dir
        # does not grow by one file per request (also on ffmpeg failure).
        try:
            os.remove(audio_path)
        except FileNotFoundError:
            pass  # e.g. opening the file failed before it was created
    rec_result = inference_pipeline_asr(audio_in=audio_bytes, param_dict={})
    # NOTE(review): the ASR pipeline may omit 'text' for empty/silent audio — confirm.
    text = rec_result.get('text', '')
    # inference_pipeline_punc is None when the server was started with --punc_model "".
    if add_pun and inference_pipeline_punc is not None and text:
        rec_result = inference_pipeline_punc(text_in=text, param_dict={'cache': list()})
        text = rec_result['text']
    ret = {"results": text, "code": 0}
    return ret
if __name__ == '__main__':
    # TLS material comes from --certfile/--keyfile (both None by default).
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        ssl_certfile=args.certfile,
        ssl_keyfile=args.keyfile,
    )
funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
@@ -5,7 +5,7 @@
model = ContextualParaformer(model_dir, batch_size=1)
wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
hotwords = '随机热词 各种热词 魔搭 阿里巴巴'
hotwords = '随机热词 各种热词 魔搭 阿里巴巴 仏'
result = model(wav_path, hotwords)
print(result)
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -333,7 +333,14 @@
        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
        # hotwords.append('<s>')
        def word_map(word):
            return torch.tensor([self.vocab[i] for i in word])
            hotwords = []
            for c in word:
                if c not in self.vocab.keys():
                    hotwords.append(8403)
                    logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
                else:
                    hotwords.append(self.vocab[c])
            return torch.tensor(hotwords)
        hotword_int = [word_map(i) for i in hotwords]
        # import pdb; pdb.set_trace()
        hotword_int.append(torch.tensor([1]))