游雁
2024-02-19 94de39dde2e616a01683c518023d0fab72b4e103
runtime/python/http/server.py
@@ -6,13 +6,13 @@
import aiofiles
import ffmpeg
import uvicorn
from fastapi import FastAPI, File, UploadFile, Body
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from fastapi import FastAPI, File, UploadFile
from modelscope.utils.logger import get_logger
from funasr import AutoModel

# Server-wide logger at INFO level, so that the configuration dump, model
# loading messages and per-request recognition results below are emitted.
logger = get_logger(log_level=logging.INFO)
logger.setLevel(logging.INFO)
parser = argparse.ArgumentParser()
parser.add_argument("--host",
@@ -27,27 +27,43 @@
                    help="server port")
# Model selection. Each model name is passed to AutoModel together with its
# matching *_revision value (e.g. --asr_model / --asr_model_revision).
parser.add_argument("--asr_model",
                    type=str,
                    default="paraformer-zh",
                    help="asr model from https://github.com/alibaba-damo-academy/FunASR?tab=readme-ov-file#model-zoo")
parser.add_argument("--asr_model_revision",
                    type=str,
                    default="v2.0.4",
                    help="")
parser.add_argument("--vad_model",
                    type=str,
                    default="fsmn-vad",
                    help="vad model from https://github.com/alibaba-damo-academy/FunASR?tab=readme-ov-file#model-zoo")
parser.add_argument("--vad_model_revision",
                    type=str,
                    default="v2.0.4",
                    help="")
parser.add_argument("--punc_model",
                    type=str,
                    default="ct-punc-c",
                    help="model from https://github.com/alibaba-damo-academy/FunASR?tab=readme-ov-file#model-zoo")
parser.add_argument("--punc_model_revision",
                    type=str,
                    default="v2.0.4",
                    help="")
# Compute resources; all three values are forwarded unchanged to AutoModel.
parser.add_argument("--ngpu",
                    type=int,
                    default=1,
                    help="0 for cpu, 1 for gpu")
parser.add_argument("--device",
                    type=str,
                    default="cuda",
                    help="cuda, cpu")
parser.add_argument("--ncpu",
                    type=int,
                    default=4,
                    help="cpu cores")
# Optional UTF-8 text file read at startup; if it exists, its lines are joined
# with spaces and passed to the model as the `hotword` decoding option.
parser.add_argument("--hotword_path",
                    type=str,
                    default='hotwords.txt',
                    help="hot word txt path, only the hot word model works")
parser.add_argument("--certfile",
                    type=str,
@@ -65,58 +81,74 @@
                    required=False,
                    help="temp dir")
args = parser.parse_args()

# Dump the effective configuration once at startup.
logger.info("-----------  Configuration Arguments -----------")
for arg, value in vars(args).items():
    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    logger.info("%s: %s", arg, value)
logger.info("------------------------------------------------")

# Uploaded audio is staged here before ffmpeg decodes it.
os.makedirs(args.temp_dir, exist_ok=True)

logger.info("model loading")
# Load the FunASR model: the ASR model plus the VAD and punctuation models.
# An empty --vad_model / --punc_model is passed as None rather than "", so that
# AutoModel does not try to load a model whose name is the empty string.
# NOTE(review): assumes AutoModel skips a sub-model given as None (treats it as
# "not provided") — confirm against the installed funasr version.
model = AutoModel(model=args.asr_model,
                  model_revision=args.asr_model_revision,
                  vad_model=args.vad_model or None,
                  vad_model_revision=args.vad_model_revision,
                  punc_model=args.punc_model or None,
                  punc_model_revision=args.punc_model_revision,
                  ngpu=args.ngpu,
                  ncpu=args.ncpu,
                  device=args.device,
                  disable_pbar=True,
                  disable_log=True)
logger.info("loaded models!")

app = FastAPI(title="FunASR")

# Decoding options forwarded to model.generate() on every request.
# "sentence_timestamp" presumably makes generate() return the per-sentence
# "sentence_info" that /recognition reads; "batch_size_s" is forwarded as-is
# (presumably seconds of audio per batch — confirm with the FunASR docs).
param_dict = {"sentence_timestamp": True, "batch_size_s": 300}
if args.hotword_path is not None and os.path.isfile(args.hotword_path):
    with open(args.hotword_path, 'r', encoding='utf-8') as f:
        # One hot word per line; blank lines are dropped so they do not turn
        # into empty entries / doubled separators in the joined string.
        hotword_list = [line.strip() for line in f if line.strip()]
    if hotword_list:
        hotword = ' '.join(hotword_list)
        logger.info(f'热词:{hotword}')
        param_dict['hotword'] = hotword
@app.post("/recognition")
async def api_recognition(audio: UploadFile = File(..., description="audio file")):
    """Transcribe one uploaded audio file.

    The upload is staged in ``args.temp_dir``, then decoded by ffmpeg to
    16 kHz mono signed 16-bit little-endian PCM. The result is passed to
    ``model.generate`` with the module-level ``param_dict``. The staged file is
    always removed afterwards.

    Returns:
        ``{"text": str, "sentences": [{"text", "start", "end"}, ...], "code": 0}``
        on success. An empty text and sentence list are returned when the model
        yields no result.
        ``{"msg": ..., "code": 1}`` if ffmpeg cannot decode the upload.
        ``{"msg": ..., "code": -1}`` if the model returns more than one result.
    """
    # The client-supplied filename is untrusted and may be missing: only keep a
    # plain ASCII alphanumeric extension from it, never any path components.
    ext = os.path.splitext(audio.filename or '')[1]
    if not (ext[1:].isascii() and ext[1:].isalnum()):
        ext = ''
    audio_path = os.path.join(args.temp_dir, f'{uuid.uuid1()}{ext}')

    try:
        async with aiofiles.open(audio_path, 'wb') as out_file:
            content = await audio.read()
            await out_file.write(content)
        try:
            audio_bytes, _ = (
                ffmpeg.input(audio_path, threads=0)
                .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
                .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
            )
        except Exception as e:
            # Undecodable or corrupt upload: report it to the client instead of a 500.
            logger.error(f'读取音频文件发生错误,错误信息:{e}')
            return {"msg": "读取音频文件发生错误", "code": 1}
    finally:
        # The decoded PCM is held in memory, so the staged copy is no longer
        # needed; previously these files accumulated in temp_dir forever.
        try:
            os.remove(audio_path)
        except FileNotFoundError:
            pass  # the write failed before the file was created
        except OSError as e:
            logger.warning(f'failed to remove temp file {audio_path}: {e}')

    # NOTE(review): ffmpeg and model.generate are synchronous and block the event
    # loop for the duration of the request.
    rec_results = model.generate(input=audio_bytes, is_final=True, **param_dict)

    # No recognition result at all.
    if not rec_results:
        return {"text": "", "sentences": [], "code": 0}
    # A single input is expected to yield exactly one result.
    if len(rec_results) > 1:
        logger.info(f'识别结果:{rec_results}')
        return {"msg": "未知错误", "code": -1}

    rec_result = rec_results[0]
    # Per-sentence text with its own start/end timestamps. "sentence_info" may
    # presumably be absent (e.g. without a punctuation model), so default to [].
    sentences = [
        {'text': sentence['text'], 'start': sentence['start'], 'end': sentence['end']}
        for sentence in rec_result.get('sentence_info', [])
    ]
    ret = {"text": rec_result['text'], "sentences": sentences, "code": 0}
    logger.info(f'识别结果:{ret}')
    return ret
if __name__ == '__main__':