From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/runtime/python/http/client.py                            |   34 ++++++++
 funasr/bin/punc_infer.py                                        |   23 +++++
 funasr/runtime/python/http/README.md                            |   47 +++++++++++
 funasr/runtime/python/http/requirements.txt                     |    6 +
 funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py |    9 ++
 funasr/runtime/python/http/server.py                            |  107 ++++++++++++++++++++++++++
 funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py |    2 
 funasr/bin/asr_inference_launch.py                              |    8 +-
 funasr/export/export_model.py                                   |    2 
 9 files changed, 229 insertions(+), 9 deletions(-)

diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 2def98b..cdaaefc 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -415,7 +415,7 @@
                         ibest_writer["rtf"][key] = rtf_cur
 
                     if text is not None:
-                        if use_timestamp and timestamp is not None:
+                        if use_timestamp and timestamp is not None and len(timestamp):
                             postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
                         else:
                             postprocessed_result = postprocess_utils.sentence_postprocess(token)
@@ -427,7 +427,7 @@
                         else:
                             text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
                         item = {'key': key, 'value': text_postprocessed}
-                        if timestamp_postprocessed != "":
+                        if timestamp_postprocessed != "" or len(timestamp) == 0:
                             item['timestamp'] = timestamp_postprocessed
                         asr_result_list.append(item)
                         finish_count += 1
@@ -692,7 +692,7 @@
             text, token, token_int = result[0], result[1], result[2]
             time_stamp = result[4] if len(result[4]) > 0 else None
 
-            if use_timestamp and time_stamp is not None:
+            if use_timestamp and time_stamp is not None and len(time_stamp):
                 postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
             else:
                 postprocessed_result = postprocess_utils.sentence_postprocess(token)
@@ -717,7 +717,7 @@
             item = {'key': key, 'value': text_postprocessed_punc}
             if text_postprocessed != "":
                 item['text_postprocessed'] = text_postprocessed
-            if time_stamp_postprocessed != "":
+            if time_stamp_postprocessed != "" or len(time_stamp) == 0:
                 item['time_stamp'] = time_stamp_postprocessed
 
             item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
diff --git a/funasr/bin/punc_infer.py b/funasr/bin/punc_infer.py
index 7b61717..9efeb5b 100644
--- a/funasr/bin/punc_infer.py
+++ b/funasr/bin/punc_infer.py
@@ -117,12 +117,25 @@
             new_mini_sentence_punc += [int(x) for x in punctuations_np]
             words_with_punc = []
             for i in range(len(mini_sentence)):
+                if (i==0 or self.punc_list[punctuations[i-1]] == "。" or self.punc_list[punctuations[i-1]] == "？") and len(mini_sentence[i][0].encode()) == 1:
+                    mini_sentence[i] = mini_sentence[i].capitalize()
+                if i == 0:
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        mini_sentence[i] = " " + mini_sentence[i]
                 if i > 0:
                     if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
                         mini_sentence[i] = " " + mini_sentence[i]
                 words_with_punc.append(mini_sentence[i])
                 if self.punc_list[punctuations[i]] != "_":
-                    words_with_punc.append(self.punc_list[punctuations[i]])
+                    punc_res = self.punc_list[punctuations[i]]
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        if punc_res == "，":
+                            punc_res = ","
+                        elif punc_res == "。":
+                            punc_res = "."
+                        elif punc_res == "？":
+                            punc_res = "?"
+                    words_with_punc.append(punc_res)
             new_mini_sentence += "".join(words_with_punc)
             # Add Period for the end of the sentence
             new_mini_sentence_out = new_mini_sentence
@@ -131,9 +144,15 @@
                 if new_mini_sentence[-1] == "，" or new_mini_sentence[-1] == "、":
                     new_mini_sentence_out = new_mini_sentence[:-1] + "。"
                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
-                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "？":
+                elif new_mini_sentence[-1] == ",":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "？" and len(new_mini_sentence[-1].encode())==0:
                     new_mini_sentence_out = new_mini_sentence + "。"
                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
+                    new_mini_sentence_out = new_mini_sentence + "."
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
         return new_mini_sentence_out, new_mini_sentence_punc_out
 
 
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index e0a9313..6ab9408 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -254,7 +254,7 @@
             if not os.path.exists(quant_model_path):
                 onnx_model = onnx.load(model_path)
                 nodes = [n.name for n in onnx_model.graph.node]
-                nodes_to_exclude = [m for m in nodes if 'output' in m]
+                nodes_to_exclude = [m for m in nodes if 'output' in m or 'bias_encoder' in m  or 'bias_decoder' in m]
                 quantize_dynamic(
                     model_input=model_path,
                     model_output=quant_model_path,
diff --git a/funasr/runtime/python/http/README.md b/funasr/runtime/python/http/README.md
new file mode 100644
index 0000000..5b3fbb3
--- /dev/null
+++ b/funasr/runtime/python/http/README.md
@@ -0,0 +1,47 @@
+# Service with http-python
+
+## Server
+
+1. Install requirements
+
+```shell
+cd funasr/runtime/python/http
+pip install -r requirements.txt
+```
+
+2. Start server
+
+```shell
+python server.py --port 8000
+```
+
+More parameters:
+```shell
+python server.py \
+--host [host ip] \
+--port [server port] \
+--asr_model [asr model_name] \
+--punc_model [punc model_name] \
+--ngpu [0 or 1] \
+--ncpu [1 or 4] \
+--certfile [path of certfile for ssl] \
+--keyfile [path of keyfile for ssl] \
+--temp_dir [upload file temp dir] 
+```
+
+## Client
+
+```shell
+# get test audio file
+wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav
+python client.py --host=127.0.0.1 --port=8000 --audio_path=asr_example_zh.wav
+```
+
+More parameters:
+```shell
+python client.py \
+--host [server ip] \
+--port [server port] \
+--add_pun [add punctuation to result] \
+--audio_path [audio file path] 
+```
diff --git a/funasr/runtime/python/http/client.py b/funasr/runtime/python/http/client.py
new file mode 100644
index 0000000..09e9eea
--- /dev/null
+++ b/funasr/runtime/python/http/client.py
@@ -0,0 +1,34 @@
+import requests
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--host",
+                    type=str,
+                    default="127.0.0.1",
+                    required=False,
+                    help="sever ip")
+parser.add_argument("--port",
+                    type=int,
+                    default=8000,
+                    required=False,
+                    help="server port")
+parser.add_argument("--add_pun",
+                    type=int,
+                    default=1,
+                    required=False,
+                    help="add pun to result")
+parser.add_argument("--audio_path",
+                    type=str,
+                    default='asr_example_zh.wav',
+                    required=False,
+                    help="use audio path")
+args = parser.parse_args()
+
+
+url = f'http://{args.host}:{args.port}/recognition'
+data = {'add_pun': args.add_pun}
+headers = {}
+files = [('audio', ('file', open(args.audio_path, 'rb'), 'application/octet-stream'))]
+
+response = requests.post(url, headers=headers, data=data, files=files)
+print(response.text)
diff --git a/funasr/runtime/python/http/requirements.txt b/funasr/runtime/python/http/requirements.txt
new file mode 100644
index 0000000..bf55b9e
--- /dev/null
+++ b/funasr/runtime/python/http/requirements.txt
@@ -0,0 +1,6 @@
+modelscope>=1.8.4
+fastapi>=0.95.1
+ffmpeg-python
+aiofiles
+uvicorn
+requests
\ No newline at end of file
diff --git a/funasr/runtime/python/http/server.py b/funasr/runtime/python/http/server.py
new file mode 100644
index 0000000..283cf0a
--- /dev/null
+++ b/funasr/runtime/python/http/server.py
@@ -0,0 +1,107 @@
+import argparse
+import logging
+import os
+import random
+import time
+
+import aiofiles
+import ffmpeg
+import uvicorn
+from fastapi import FastAPI, File, UploadFile, Body
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger(log_level=logging.CRITICAL)
+logger.setLevel(logging.CRITICAL)
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--host",
+                    type=str,
+                    default="0.0.0.0",
+                    required=False,
+                    help="host ip, localhost, 0.0.0.0")
+parser.add_argument("--port",
+                    type=int,
+                    default=8000,
+                    required=False,
+                    help="server port")
+parser.add_argument("--asr_model",
+                    type=str,
+                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+                    help="model from modelscope")
+parser.add_argument("--punc_model",
+                    type=str,
+                    default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
+                    help="model from modelscope")
+parser.add_argument("--ngpu",
+                    type=int,
+                    default=1,
+                    help="0 for cpu, 1 for gpu")
+parser.add_argument("--ncpu",
+                    type=int,
+                    default=4,
+                    help="cpu cores")
+parser.add_argument("--certfile",
+                    type=str,
+                    default=None,
+                    required=False,
+                    help="certfile for ssl")
+parser.add_argument("--keyfile",
+                    type=str,
+                    default=None,
+                    required=False,
+                    help="keyfile for ssl")
+parser.add_argument("--temp_dir",
+                    type=str,
+                    default="temp_dir/",
+                    required=False,
+                    help="temp dir")
+args = parser.parse_args()
+
+os.makedirs(args.temp_dir, exist_ok=True)
+
+print("model loading")
+# asr
+inference_pipeline_asr = pipeline(task=Tasks.auto_speech_recognition,
+                                  model=args.asr_model,
+                                  ngpu=args.ngpu,
+                                  ncpu=args.ncpu,
+                                  model_revision=None)
+print(f'loaded asr models.')
+
+if args.punc_model != "":
+    inference_pipeline_punc = pipeline(task=Tasks.punctuation,
+                                       model=args.punc_model,
+                                       model_revision="v1.0.2",
+                                       ngpu=args.ngpu,
+                                       ncpu=args.ncpu)
+    print(f'loaded pun models.')
+else:
+    inference_pipeline_punc = None
+
+app = FastAPI(title="FunASR")
+
+
+@app.post("/recognition")
+async def api_recognition(audio: UploadFile = File(..., description="audio file"),
+                          add_pun: int = Body(1, description="add punctuation", embed=True)):
+    suffix = audio.filename.split('.')[-1]
+    audio_path = f'{args.temp_dir}/{int(time.time() * 1000)}_{random.randint(100, 999)}.{suffix}'
+    async with aiofiles.open(audio_path, 'wb') as out_file:
+        content = await audio.read()
+        await out_file.write(content)
+    audio_bytes, _ = (
+        ffmpeg.input(audio_path, threads=0)
+        .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
+        .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
+    )
+    rec_result = inference_pipeline_asr(audio_in=audio_bytes, param_dict={})
+    if add_pun:
+        rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict={'cache': list()})
+    ret = {"results": rec_result['text'], "code": 0}
+    return ret
+
+
+if __name__ == '__main__':
+    uvicorn.run(app, host=args.host, port=args.port, ssl_keyfile=args.keyfile, ssl_certfile=args.certfile)
diff --git a/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py b/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
index 984c0d6..9da3817 100644
--- a/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
+++ b/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
@@ -5,7 +5,7 @@
 model = ContextualParaformer(model_dir, batch_size=1)
 
 wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
-hotwords = '闅忔満鐑瘝 鍚勭鐑瘝 榄旀惌 闃块噷宸村反'
+hotwords = '闅忔満鐑瘝 鍚勭鐑瘝 榄旀惌 闃块噷宸村反 浠�'
 
 result = model(wav_path, hotwords)
 print(result)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index 884def9..8727896 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -333,7 +333,14 @@
         hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
         # hotwords.append('<s>')
         def word_map(word):
-            return torch.tensor([self.vocab[i] for i in word])
+            hotwords = []
+            for c in word:
+                if c not in self.vocab.keys():
+                    hotwords.append(8403)
+                    logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
+                else:
+                    hotwords.append(self.vocab[c])
+            return torch.tensor(hotwords)
         hotword_int = [word_map(i) for i in hotwords]
         # import pdb; pdb.set_trace()
         hotword_int.append(torch.tensor([1]))

--
Gitblit v1.9.1