夜雨飘零
2024-02-02 85c08383831ea2b7cdf4c6f863f71b20b95b6782
support funasr 1.0 (#1346)

* support funasr 1.0

* update docs
4个文件已修改
1个文件已添加
118 ■■■■ 已修改文件
runtime/python/http/README.md 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/http/client.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/http/hotwords.txt 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/http/requirements.txt 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/http/server.py 102 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/http/README.md
@@ -23,6 +23,7 @@
--asr_model [asr model_name] \
--vad_model [vad model_name] \
--punc_model [punc model_name] \
--device [cuda or cpu] \
--ngpu [0 or 1] \
--ncpu [1 or 4] \
--hotword_path [path of hot word txt] \
@@ -44,7 +45,6 @@
python server.py \
--host [sever ip] \
--port [sever port] \
--add_pun [add pun to result] \
--audio_path [use audio path] 
```
runtime/python/http/client.py
@@ -14,11 +14,6 @@
                    default=8000,
                    required=False,
                    help="server port")
parser.add_argument("--add_pun",
                    type=int,
                    default=1,
                    required=False,
                    help="add pun to result")
parser.add_argument("--audio_path",
                    type=str,
                    default='asr_example_zh.wav',
@@ -32,9 +27,8 @@
url = f'http://{args.host}:{args.port}/recognition'
data = {'add_pun': args.add_pun}
headers = {}
files = [('audio', (os.path.basename(args.audio_path), open(args.audio_path, 'rb'), 'application/octet-stream'))]
response = requests.post(url, headers=headers, data=data, files=files)
response = requests.post(url, headers=headers, files=files)
print(response.text)
runtime/python/http/hotwords.txt
New file
@@ -0,0 +1,2 @@
阿里巴巴
通义实验室
runtime/python/http/requirements.txt
@@ -1,6 +1,6 @@
modelscope>=1.8.4
modelscope>=1.11.1
funasr>=1.0.5
fastapi>=0.95.1
ffmpeg-python
aiofiles
uvicorn
requests
runtime/python/http/server.py
@@ -4,15 +4,14 @@
import uuid
import aiofiles
import ffmpeg
import uvicorn
from fastapi import FastAPI, File, UploadFile, Body
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from fastapi import FastAPI, File, UploadFile
from modelscope.utils.logger import get_logger
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)
from funasr import AutoModel
logger = get_logger(log_level=logging.INFO)
logger.setLevel(logging.INFO)
parser = argparse.ArgumentParser()
parser.add_argument("--host",
@@ -27,27 +26,43 @@
                    help="server port")
parser.add_argument("--asr_model",
                    type=str,
                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                    help="offline asr model from modelscope")
                    default="paraformer-zh",
                    help="asr model from https://github.com/alibaba-damo-academy/FunASR?tab=readme-ov-file#model-zoo")
parser.add_argument("--asr_model_revision",
                    type=str,
                    default="v2.0.4",
                    help="")
parser.add_argument("--vad_model",
                    type=str,
                    default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                    help="vad model from modelscope")
                    default="fsmn-vad",
                    help="vad model from https://github.com/alibaba-damo-academy/FunASR?tab=readme-ov-file#model-zoo")
parser.add_argument("--vad_model_revision",
                    type=str,
                    default="v2.0.4",
                    help="")
parser.add_argument("--punc_model",
                    type=str,
                    default="damo/punc_ct-transformer_cn-en-common-vocab471067-large",
                    help="punc model from modelscope")
                    default="ct-punc-c",
                    help="model from https://github.com/alibaba-damo-academy/FunASR?tab=readme-ov-file#model-zoo")
parser.add_argument("--punc_model_revision",
                    type=str,
                    default="v2.0.4",
                    help="")
parser.add_argument("--ngpu",
                    type=int,
                    default=1,
                    help="0 for cpu, 1 for gpu")
parser.add_argument("--device",
                    type=str,
                    default="cuda",
                    help="cuda, cpu")
parser.add_argument("--ncpu",
                    type=int,
                    default=4,
                    help="cpu cores")
parser.add_argument("--hotword_path",
                    type=str,
                    default=None,
                    default='hotwords.txt',
                    help="hot word txt path, only the hot word model works")
parser.add_argument("--certfile",
                    type=str,
@@ -65,57 +80,50 @@
                    required=False,
                    help="temp dir")
args = parser.parse_args()
print("-----------  Configuration Arguments -----------")
logger.info("-----------  Configuration Arguments -----------")
for arg, value in vars(args).items():
    print("%s: %s" % (arg, value))
print("------------------------------------------------")
    logger.info("%s: %s" % (arg, value))
logger.info("------------------------------------------------")
os.makedirs(args.temp_dir, exist_ok=True)
print("model loading")
param_dict = {}
if args.hotword_path is not None and os.path.exists(args.hotword_path):
    param_dict['hotword'] = args.hotword_path
# asr
inference_pipeline_asr = pipeline(task=Tasks.auto_speech_recognition,
                                  model=args.asr_model,
logger.info("model loading")
# load funasr model
model = AutoModel(model=args.asr_model,
                  model_revision=args.asr_model_revision,
                                  vad_model=args.vad_model,
                  vad_model_revision=args.vad_model_revision,
                  punc_model=args.punc_model,
                  punc_model_revision=args.punc_model_revision,
                                  ngpu=args.ngpu,
                                  ncpu=args.ncpu,
                                  param_dict=param_dict)
print(f'loaded asr models.')
if args.punc_model != "":
    inference_pipeline_punc = pipeline(task=Tasks.punctuation,
                                       model=args.punc_model,
                                       ngpu=args.ngpu,
                                       ncpu=args.ncpu)
    print(f'loaded pun models.')
else:
    inference_pipeline_punc = None
                  device=args.device,
                  disable_pbar=True,
                  disable_log=True)
logger.info("loaded models!")
app = FastAPI(title="FunASR")
param_dict = {}
if args.hotword_path is not None and os.path.exists(args.hotword_path):
    with open(args.hotword_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        lines = [line.strip() for line in lines]
    hotword = ' '.join(lines)
    logger.info(f'热词:{hotword}')
    param_dict['hotword'] = hotword
@app.post("/recognition")
async def api_recognition(audio: UploadFile = File(..., description="audio file"),
                          add_pun: int = Body(1, description="add punctuation", embed=True)):
async def api_recognition(audio: UploadFile = File(..., description="audio file")):
    suffix = audio.filename.split('.')[-1]
    audio_path = f'{args.temp_dir}/{str(uuid.uuid1())}.{suffix}'
    async with aiofiles.open(audio_path, 'wb') as out_file:
        content = await audio.read()
        await out_file.write(content)
    audio_bytes, _ = (
        ffmpeg.input(audio_path, threads=0)
        .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
        .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
    )
    rec_result = inference_pipeline_asr(audio_in=audio_bytes, param_dict={})
    if add_pun:
        rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict={'cache': list()})
    ret = {"results": rec_result['text'], "code": 0}
    print(ret)
    rec_result = model.generate(input=audio_path, batch_size_s=300, **param_dict)
    ret = {"result": rec_result[0]['text'], "code": 0}
    logger.info(f'识别结果:{ret}')
    return ret