From 2bd82419482b9a074c8c464d5b65f997632964d3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 18 九月 2023 10:44:25 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/runtime/python/http/server.py |   32 ++++++++++++++++++++++++--------
 1 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/funasr/runtime/python/http/server.py b/funasr/runtime/python/http/server.py
index 283cf0a..19d3193 100644
--- a/funasr/runtime/python/http/server.py
+++ b/funasr/runtime/python/http/server.py
@@ -1,8 +1,7 @@
 import argparse
 import logging
 import os
-import random
-import time
+import uuid
 
 import aiofiles
 import ffmpeg
@@ -29,11 +28,15 @@
 parser.add_argument("--asr_model",
                     type=str,
                     default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
-                    help="model from modelscope")
+                    help="offline asr model from modelscope")
+parser.add_argument("--vad_model",
+                    type=str,
+                    default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                    help="vad model from modelscope")
 parser.add_argument("--punc_model",
                     type=str,
-                    default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
-                    help="model from modelscope")
+                    default="damo/punc_ct-transformer_cn-en-common-vocab471067-large",
+                    help="punc model from modelscope")
 parser.add_argument("--ngpu",
                     type=int,
                     default=1,
@@ -42,6 +45,10 @@
                     type=int,
                     default=4,
                     help="cpu cores")
+parser.add_argument("--hotword_path",
+                    type=str,
+                    default=None,
+                    help="hot word txt path, only the hot word model works")
 parser.add_argument("--certfile",
                     type=str,
                     default=None,
@@ -58,22 +65,30 @@
                     required=False,
                     help="temp dir")
 args = parser.parse_args()
+print("-----------  Configuration Arguments -----------")
+for arg, value in vars(args).items():
+    print("%s: %s" % (arg, value))
+print("------------------------------------------------")
+
 
 os.makedirs(args.temp_dir, exist_ok=True)
 
 print("model loading")
+param_dict = {}
+if args.hotword_path is not None and os.path.exists(args.hotword_path):
+    param_dict['hotword'] = args.hotword_path
 # asr
 inference_pipeline_asr = pipeline(task=Tasks.auto_speech_recognition,
                                   model=args.asr_model,
+                                  vad_model=args.vad_model,
                                   ngpu=args.ngpu,
                                   ncpu=args.ncpu,
-                                  model_revision=None)
+                                  param_dict=param_dict)
 print(f'loaded asr models.')
 
 if args.punc_model != "":
     inference_pipeline_punc = pipeline(task=Tasks.punctuation,
                                        model=args.punc_model,
-                                       model_revision="v1.0.2",
                                        ngpu=args.ngpu,
                                        ncpu=args.ncpu)
     print(f'loaded pun models.')
@@ -87,7 +102,7 @@
 async def api_recognition(audio: UploadFile = File(..., description="audio file"),
                           add_pun: int = Body(1, description="add punctuation", embed=True)):
     suffix = audio.filename.split('.')[-1]
-    audio_path = f'{args.temp_dir}/{int(time.time() * 1000)}_{random.randint(100, 999)}.{suffix}'
+    audio_path = f'{args.temp_dir}/{str(uuid.uuid1())}.{suffix}'
     async with aiofiles.open(audio_path, 'wb') as out_file:
         content = await audio.read()
         await out_file.write(content)
@@ -100,6 +115,7 @@
     if add_pun:
         rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict={'cache': list()})
     ret = {"results": rec_result['text'], "code": 0}
+    print(ret)
     return ret
 
 

--
Gitblit v1.9.1