游雁
2024-06-06 783a051f6586d392bec6e1494607f79635b71741
Merge branch 'dev_gzf_deepspeed' of github.com:alibaba-damo-academy/FunASR into dev_gzf_deepspeed
merge
38 files modified
31 files added
4238 lines changed
README.md 52
README_zh.md 48
docs/images/wechat.png
funasr/auto/auto_frontend.py 24
funasr/auto/auto_model.py 8
funasr/datasets/audio_datasets/espnet_samplers.py 2
funasr/download/name_maps_from_hub.py 22
funasr/models/llm_asr/adaptor.py 63
funasr/models/seaco_paraformer/export_meta.py 10
funasr/models/sense_voice/decoder.py 1
funasr/models/sense_voice/model.py 28
funasr/models/sense_voice/whisper_lib/model.py 19
funasr/models/transformer/encoder.py 2
funasr/train_utils/trainer_ds.py 1
runtime/funasr_api/README.md 72
runtime/funasr_api/asr_example.mp3
runtime/funasr_api/asr_example.wav
runtime/funasr_api/example.py 70
runtime/funasr_api/funasr_api.py 96
runtime/funasr_api/funasr_core.py 230
runtime/funasr_api/funasr_stream.py 72
runtime/funasr_api/funasr_tools.py 84
runtime/http/CMakeLists.txt 142
runtime/http/bin/CMakeLists.txt 23
runtime/http/bin/asr_sessions.h 20
runtime/http/bin/connection.cpp 196
runtime/http/bin/connection.hpp 104
runtime/http/bin/file_parse.cpp 29
runtime/http/bin/file_parse.hpp 234
runtime/http/bin/funasr-http-main.cpp 523
runtime/http/bin/funasr-http-main.hpp 20
runtime/http/bin/header.hpp 27
runtime/http/bin/io_context_pool.cpp 66
runtime/http/bin/io_context_pool.hpp 59
runtime/http/bin/model-decoder.cpp 119
runtime/http/bin/model-decoder.h 60
runtime/http/bin/reply.cpp 245
runtime/http/bin/reply.hpp 64
runtime/http/bin/server.cpp 113
runtime/http/bin/server.hpp 71
runtime/http/readme.md 58
runtime/http/readme_zh.md 61
runtime/http/requirements_install.md 15
runtime/onnxruntime/CMakeLists.txt 12
runtime/onnxruntime/bin/CMakeLists.txt 10
runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp 15
runtime/onnxruntime/bin/funasr-onnx-offline.cpp 40
runtime/onnxruntime/include/audio.h 4
runtime/onnxruntime/include/com-define.h 9
runtime/onnxruntime/include/funasrruntime.h 6
runtime/onnxruntime/include/model.h 12
runtime/onnxruntime/include/offline-stream.h 4
runtime/onnxruntime/src/CMakeLists.txt 19
runtime/onnxruntime/src/audio.cpp 106
runtime/onnxruntime/src/funasrruntime.cpp 182
runtime/onnxruntime/src/offline-stream.cpp 29
runtime/onnxruntime/src/paraformer-torch.cpp 415
runtime/onnxruntime/src/paraformer-torch.h 96
runtime/onnxruntime/src/paraformer.cpp 24
runtime/onnxruntime/src/paraformer.h 4
runtime/onnxruntime/src/precomp.h 3
runtime/onnxruntime/third_party/jieba/include/limonp/StdExtension.hpp 4
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py 41
runtime/websocket/CMakeLists.txt 20
runtime/websocket/bin/CMakeLists.txt 8
runtime/websocket/bin/funasr-wss-server.cpp 14
runtime/websocket/bin/websocket-server.cpp 4
runtime/websocket/bin/websocket-server.h 2
setup.py 2
README.md
@@ -2,8 +2,9 @@
([简体中文](./README_zh.md)|English)
# FunASR: A Fundamental End-to-End Speech Recognition Toolkit
[//]: # (# FunASR: A Fundamental End-to-End Speech Recognition Toolkit)
[![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=FunASR🤠&text2=💖%20A%20Fundamental%20End-to-End%20Speech%20Recognition%20Toolkit&width=800&height=210)](https://github.com/Akshay090/svg-banners)
[![PyPI](https://img.shields.io/pypi/v/funasr)](https://pypi.org/project/funasr/)
@@ -34,6 +35,9 @@
- 2024/03/05:Added support for the Whisper-large-v3 model, a multitasking model that can perform multilingual speech recognition, speech translation, and language identification. It can be downloaded from the[modelscope](examples/industrial_data_pretraining/whisper/demo.py), and [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py).
- 2024/03/05: Offline File Transcription Service 4.4, Offline File Transcription Service of English 1.5,Real-time Transcription Service 1.9 released,docker image supports ARM64 platform, update modelscope;([docs](runtime/readme.md))
- 2024/01/30:funasr-1.0 has been released ([docs](https://github.com/alibaba-damo-academy/FunASR/discussions/1319))
<details><summary>Full Changelog</summary>
- 2024/01/30:emotion recognition models are new supported. [model link](https://www.modelscope.cn/models/iic/emotion2vec_base_finetuned/summary), modified from [repo](https://github.com/ddlBoJack/emotion2vec).
- 2024/01/25: Offline File Transcription Service 4.2, Offline File Transcription Service of English 1.3 released,optimized the VAD (Voice Activity Detection) data processing method, significantly reducing peak memory usage, memory leak optimization; Real-time Transcription Service 1.7 released,optimizatized the client-side;([docs](runtime/readme.md))
- 2024/01/09: The Funasr SDK for Windows version 2.0 has been released, featuring support for The offline file transcription service (CPU) of Mandarin 4.1, The offline file transcription service (CPU) of English 1.2, The real-time transcription service (CPU) of Mandarin 1.6. For more details, please refer to the official documentation or release notes([FunASR-Runtime-Windows](https://www.modelscope.cn/models/damo/funasr-runtime-win-cpu-x64/summary))
@@ -51,22 +55,31 @@
- 2023/07/17: BAT is released, which is a low-latency and low-memory-consumption RNN-T model. For more details, please refer to ([BAT](egs/aishell/bat)).
- 2023/06/26: ASRU2023 Multi-Channel Multi-Party Meeting Transcription Challenge 2.0 completed the competition and announced the results. For more details, please refer to ([M2MeT2.0](https://alibaba-damo-academy.github.io/FunASR/m2met2/index.html)).
</details>
<a name="Installation"></a>
## Installation
- Requirements
```text
python>=3.8
torch>=1.13
torchaudio
```
- Install for pypi
```shell
pip3 install -U funasr
```
Or install from source code
- Or install from source code
``` sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip3 install -e ./
```
Install modelscope for the pretrained models (Optional)
- Install modelscope or huggingface_hub for the pretrained models (Optional)
```shell
pip3 install -U modelscope
pip3 install -U modelscope huggingface_hub
```
## Model Zoo
@@ -77,19 +90,19 @@
|                                                                                                         Model Name                                                                                                         |                     Task Details                      |          Training Data           | Parameters |
|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------:|:--------------------------------:|:----------:|
|          paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)  [🤗](https://huggingface.co/funasr/paraformer-tp) )           |  speech recognition, with timestamps, non-streaming   |      60000 hours, Mandarin       |    220M    |
|          paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)  [🤗](https://huggingface.co/funasr/paraformer-zh) )           |  speech recognition, with timestamps, non-streaming   |      60000 hours, Mandarin       |    220M    |
| <nobr>paraformer-zh-streaming <br> ( [⭐](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗](https://huggingface.co/funasr/paraformer-zh-streaming) )</nobr> |             speech recognition, streaming             |      60000 hours, Mandarin       |    220M    |
|               paraformer-en <br> ( [⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗](https://huggingface.co/funasr/paraformer-en) )                | speech recognition, without timestamps, non-streaming |       50000 hours, English       |    220M    |
|                            conformer-en <br> ( [⭐](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [🤗](https://huggingface.co/funasr/conformer-en) )                             |           speech recognition, non-streaming           |       50000 hours, English       |    220M    |
|                               ct-punc <br> ( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) )                               |                punctuation restoration                |    100M, Mandarin and English    |    1.1G    |
|                               ct-punc <br> ( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) )                               |                punctuation restoration                |    100M, Mandarin and English    |    290M    |
|                                   fsmn-vad <br> ( [⭐](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗](https://huggingface.co/funasr/fsmn-vad) )                                   |               voice activity detection                | 5000 hours, Mandarin and English |    0.4M    | 
|                                     fa-zh <br> ( [⭐](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗](https://huggingface.co/funasr/fa-zh) )                                     |                 timestamp prediction                  |       5000 hours, Mandarin       |    38M     | 
|                                       cam++ <br> ( [⭐](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗](https://huggingface.co/funasr/campplus) )                                        |           speaker verification/diarization            |            5000 hours            |    7.2M    | 
|                                                  Whisper-large-v2 <br> ([⭐](https://www.modelscope.cn/models/iic/speech_whisper-large_asr_multilingual/summary)  [🍀](https://github.com/openai/whisper) )                                                  |  speech recognition, with timestamps, non-streaming   |           multilingual           |    1550 M    |
|                                                Whisper-large-v3 <br> ([⭐](https://www.modelscope.cn/models/iic/Whisper-large-v3/summary)  [🍀](https://github.com/openai/whisper) )                                                 |  speech recognition, with timestamps, non-streaming   |           multilingual           |    1550 M    |
|                                         Qwen-Audio <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo.py)  [🤗](https://huggingface.co/Qwen/Qwen-Audio) )                                         |      audio-text multimodal models (pretraining)       |           multilingual           |  8B  |
|                   Qwen-Audio-Chat <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo_chat.py)  [🤗](https://huggingface.co/Qwen/Qwen-Audio-Chat) )                                                |          audio-text multimodal models (chat)          |           multilingual           |  8B  |
|                        emotion2vec+large <br> ([⭐](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary)  [🤗](https://huggingface.co/emotion2vec/emotion2vec_plus_large) )                        |              speech emotion recongintion              |           40000 hours            |  300M  |
|                                 Whisper-large-v2 <br> ([⭐](https://www.modelscope.cn/models/iic/speech_whisper-large_asr_multilingual/summary)  [🍀](https://github.com/openai/whisper) )                                  |  speech recognition, with timestamps, non-streaming   |           multilingual           |   1550 M   |
|                                            Whisper-large-v3 <br> ([⭐](https://www.modelscope.cn/models/iic/Whisper-large-v3/summary)  [🍀](https://github.com/openai/whisper) )                                            |  speech recognition, with timestamps, non-streaming   |           multilingual           |   1550 M   |
|                                               Qwen-Audio <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo.py)  [🤗](https://huggingface.co/Qwen/Qwen-Audio) )                                                |      audio-text multimodal models (pretraining)       |           multilingual           |     8B     |
|                                        Qwen-Audio-Chat <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo_chat.py)  [🤗](https://huggingface.co/Qwen/Qwen-Audio-Chat) )                                        |          audio-text multimodal models (chat)          |           multilingual           |     8B     |
|                              emotion2vec+large <br> ([⭐](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary)  [🤗](https://huggingface.co/emotion2vec/emotion2vec_plus_large) )                               |              speech emotion recongintion              |           40000 hours            |    300M    |
@@ -153,6 +166,8 @@
```
Note: `chunk_size` is the configuration for streaming latency.` [0,10,5]` indicates that the real-time display granularity is `10*60=600ms`, and the lookahead information is `5*60=300ms`. Each inference input is `600ms` (sample points are `16000*0.6=960`), and the output is the corresponding text. For the last speech segment input, `is_final=True` needs to be set to force the output of the last word.
<details><summary>More Examples</summary>
### Voice Activity Detection (Non-Streaming)
```python
from funasr import AutoModel
@@ -211,9 +226,24 @@
res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
print(res)
```
### Speech Emotion Recognition
```python
from funasr import AutoModel
model = AutoModel(model="emotion2vec_plus_large")
wav_file = f"{model.model_path}/example/test.wav"
res = model.generate(wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False)
print(res)
```
More usages ref to [docs](docs/tutorial/README_zh.md), 
more examples ref to [demo](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining)
</details>
## Export ONNX
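The `chunk_size` note in the README.md diff above reduces to simple arithmetic. The sketch below assumes 16 kHz mono input and takes one `chunk_size` unit to be 60 ms, as the note states. `chunk_config` and `iter_chunks` are hypothetical helpers and are not part of the FunASR API. At 16 kHz, a 600 ms chunk is 16000 × 0.6 = 9600 samples, which is 960 samples per 60 ms unit.

```python
from typing import Iterator, Sequence, Tuple

SAMPLE_RATE = 16000  # assumed input sample rate (Hz)
UNIT_MS = 60         # one chunk_size unit, per the README note


def chunk_config(chunk_size: Sequence[int], sample_rate: int = SAMPLE_RATE) -> Tuple[int, int, int]:
    """Return (display_ms, lookahead_ms, samples_per_chunk) for a [0, display, lookahead] config."""
    _, display_units, lookahead_units = chunk_size
    display_ms = display_units * UNIT_MS
    lookahead_ms = lookahead_units * UNIT_MS
    samples_per_chunk = sample_rate * display_ms // 1000
    return display_ms, lookahead_ms, samples_per_chunk


def iter_chunks(samples: Sequence[float], chunk_stride: int) -> Iterator[Tuple[Sequence[float], bool]]:
    """Yield (chunk, is_final) pairs; is_final is True only for the last chunk."""
    # ceil(len / stride); evaluates to 0 for empty input, so nothing is yielded.
    total = (len(samples) - 1) // chunk_stride + 1
    for i in range(total):
        yield samples[i * chunk_stride:(i + 1) * chunk_stride], i == total - 1


display_ms, lookahead_ms, stride = chunk_config([0, 10, 5])
print(display_ms, lookahead_ms, stride)  # 600 300 9600
```

Each chunk would then go to the streaming model with `is_final` taken from the flag. Only the last piece forces out the final word, as the note describes.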
README_zh.md
@@ -2,7 +2,11 @@
(简体中文|[English](./README.md))
# FunASR: A Fundamental End-to-End Speech Recognition Toolkit
[![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=FunASR🤠&text2=💖%20A%20Fundamental%20End-to-End%20Speech%20Recognition%20Toolkit&width=800&height=210)](https://github.com/Akshay090/svg-banners)
[//]: # (# FunASR: A Fundamental End-to-End Speech Recognition Toolkit)
[![PyPI](https://img.shields.io/pypi/v/funasr)](https://pypi.org/project/funasr/)
@@ -35,6 +39,9 @@
- 2024/03/05:新增加Whisper-large-v3模型支持,多语言语音识别/翻译/语种识别,支持从 [modelscope](examples/industrial_data_pretraining/whisper/demo.py)仓库下载,也支持从 [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py)仓库下载模型。
- 2024/03/05: 中文离线文件转写服务 4.4、英文离线文件转写服务 1.5、中文实时语音听写服务 1.9 发布,docker镜像支持arm64平台,升级modelscope版本;详细信息参阅([部署文档](runtime/readme_cn.md))
- 2024/01/30:funasr-1.0发布,更新说明[文档](https://github.com/alibaba-damo-academy/FunASR/discussions/1319)
<details><summary>展开日志</summary>
- 2024/01/30:新增加情感识别 [模型链接](https://www.modelscope.cn/models/iic/emotion2vec_base_finetuned/summary),原始模型 [repo](https://github.com/ddlBoJack/emotion2vec).
- 2024/01/25: 中文离线文件转写服务 4.2、英文离线文件转写服务 1.3,优化vad数据处理方式,大幅降低峰值内存占用,内存泄漏优化;中文实时语音听写服务 1.7 发布,客户端优化;详细信息参阅([部署文档](runtime/readme_cn.md))
- 2024/01/09: funasr社区软件包windows 2.0版本发布,支持软件包中文离线文件转写4.1、英文离线文件转写1.2、中文实时听写服务1.6的最新功能,详细信息参阅([FunASR社区软件包windows版本](https://www.modelscope.cn/models/damo/funasr-runtime-win-cpu-x64/summary))
@@ -52,21 +59,33 @@
- 2023.07.17: BAT一种低延迟低内存消耗的RNN-T模型发布,详细信息参阅([BAT](egs/aishell/bat))
- 2023.06.26: ASRU2023 多通道多方会议转录挑战赛2.0完成竞赛结果公布,详细信息参阅([M2MeT2.0](https://alibaba-damo-academy.github.io/FunASR/m2met2_cn/index.html))
</details>
<a name="安装教程"></a>
## 安装教程
- 安装funasr之前,确保已经安装了下面依赖环境:
```text
python>=3.8
torch>=1.13
torchaudio
```
- pip安装
```shell
pip3 install -U funasr
```
或者从源代码安装
- 或者从源代码安装
``` sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip3 install -e ./
```
如果需要使用工业预训练模型,安装modelscope(可选)
如果需要使用工业预训练模型,安装modelscope与huggingface_hub(可选)
```shell
pip3 install -U modelscope
pip3 install -U modelscope huggingface huggingface_hub
```
## 模型仓库
@@ -78,11 +97,11 @@
|                                                                                                     模型名字                                                                                                      |        任务详情        |      训练数据      |  参数量   | 
|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------:|:--------------:|:------:|
|    paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)  [🤗](https://huggingface.co/funasr/paraformer-tp) )    |  语音识别,带时间戳输出,非实时   |   60000小时,中文   |  220M  |
|    paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)  [🤗](https://huggingface.co/funasr/paraformer-zh) )    |  语音识别,带时间戳输出,非实时   |   60000小时,中文   |  220M  |
| paraformer-zh-streaming <br> ( [⭐](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗](https://huggingface.co/funasr/paraformer-zh-streaming) ) |      语音识别,实时       |   60000小时,中文   |  220M  |
|         paraformer-en <br> ( [⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗](https://huggingface.co/funasr/paraformer-en) )         |      语音识别,非实时      |   50000小时,英文   |  220M  |
|                      conformer-en <br> ( [⭐](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [🤗](https://huggingface.co/funasr/conformer-en) )                      |      语音识别,非实时      |   50000小时,英文   |  220M  |
|                        ct-punc <br> ( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) )                         |        标点恢复        |   100M,中文与英文   |  1.1B  |
|                        ct-punc <br> ( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) )                         |        标点恢复        |   100M,中文与英文   |  290M  |
|                            fsmn-vad <br> ( [⭐](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗](https://huggingface.co/funasr/fsmn-vad) )                             |     语音端点检测,实时      |  5000小时,中文与英文  |  0.4M  | 
|                              fa-zh <br> ( [⭐](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗](https://huggingface.co/funasr/fa-zh) )                               |      字级别时间戳预测      |   50000小时,中文   |  38M   |
|                                 cam++ <br> ( [⭐](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗](https://huggingface.co/funasr/campplus) )                                 |      说话人确认/分割      |     5000小时     |  7.2M  | 
@@ -148,6 +167,8 @@
注:`chunk_size`为流式延时配置,`[0,10,5]`表示上屏实时出字粒度为`10*60=600ms`,未来信息为`5*60=300ms`。每次推理输入为`600ms`(采样点数为`16000*0.6=960`),输出为对应文字,最后一个语音片段输入需要设置`is_final=True`来强制输出最后一个字。
<details><summary>更多例子</summary>
### 语音端点检测(非实时)
```python
from funasr import AutoModel
@@ -211,9 +232,24 @@
res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
print(res)
```
### 情感识别
```python
from funasr import AutoModel
model = AutoModel(model="emotion2vec_plus_large")
wav_file = f"{model.model_path}/example/test.wav"
res = model.generate(wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False)
print(res)
```
更详细([教程文档](docs/tutorial/README_zh.md)),
更多([模型示例](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining))
</details>
## 导出ONNX
### 从命令行导出
```shell
docs/images/wechat.png

funasr/auto/auto_frontend.py
@@ -52,7 +52,7 @@
        key_list, data_list = prepare_data_iterator(input, input_len=input_len)
        batch_size = kwargs.get("batch_size", 1)
        device = kwargs.get("device", "cpu")
        device = kwargs.get("device", "cuda")
        if device == "cpu":
            batch_size = 1
@@ -60,7 +60,7 @@
        result_list = []
        num_samples = len(data_list)
        pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
        # pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
        time0 = time.perf_counter()
        for beg_idx in range(0, num_samples, batch_size):
@@ -87,15 +87,23 @@
                speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
            )
            speech.to(device=device), speech_lengths.to(device=device)
            batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
            if kwargs.get("return_pt", True):
                speech, speech_lengths = speech.to(device=device), speech_lengths.to(device=device)
            else:
                speech, speech_lengths = speech.numpy(), speech_lengths.numpy()
            batch = {
                "input": speech,
                "input_len": speech_lengths,
                "key": key_batch,
                "data_type": "fbank",
            }
            result_list.append(batch)
            pbar.update(1)
            description = f"{meta_data}, "
            pbar.set_description(description)
            # pbar.update(1)
            # description = f"{meta_data}, "
            # pbar.set_description(description)
        time_end = time.perf_counter()
        pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
        # pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
        return result_list
funasr/auto/auto_model.py
@@ -42,8 +42,9 @@
    filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
    chars = string.ascii_letters + string.digits
    if isinstance(data_in, str) and data_in.startswith("http"):  # url
        data_in = download_from_url(data_in)
    if isinstance(data_in, str):
        if data_in.startswith("http://") or data_in.startswith("https://"):  # url
            data_in = download_from_url(data_in)
    if isinstance(data_in, str) and os.path.exists(
        data_in
@@ -284,7 +285,7 @@
            with torch.no_grad():
                res = model.inference(**batch, **kwargs)
                if isinstance(res, (list, tuple)):
                    results = res[0]
                    results = res[0] if len(res) > 0 else [{"text": ""}]
                    meta_data = res[1] if len(res) > 1 else {}
            time2 = time.perf_counter()
@@ -358,6 +359,7 @@
            results_sorted = []
            if not len(sorted_data):
                results_ret_list.append({"key": key, "text": "", "timestamp": []})
                logging.info("decoding, utt: {}, empty speech".format(key))
                continue
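The `auto_model.py` hunks make two changes. URL detection now requires an explicit `http://` or `https://` prefix. When `model.inference` returns an empty list or tuple, the code now falls back to an empty transcript. Under the previous `startswith("http")` test, a relative path such as `http_clips/a.wav` would have been treated as a URL. The sketch below shows both checks in isolation. `is_remote_input` and `first_result` are hypothetical names, and the real code also handles list files (`.scp`, `.jsonl`, ...) and other input types.

```python
from typing import Any, Sequence


def is_remote_input(data_in: Any) -> bool:
    """True only for explicit http(s) URLs; a local path like 'http_clips/a.wav' stays local."""
    return isinstance(data_in, str) and data_in.startswith(("http://", "https://"))


def first_result(res: Sequence[Any]) -> Any:
    """Use res[0] when inference returned something, else a single empty transcript."""
    return res[0] if len(res) > 0 else [{"text": ""}]


print(is_remote_input("http_clips/a.wav"), is_remote_input("https://example.com/a.wav"))
```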
funasr/datasets/audio_datasets/espnet_samplers.py
@@ -147,7 +147,9 @@
        start_idx = self.rank * batches_per_rank
        end_idx = start_idx + batches_per_rank
        rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
        self.batch_num = len(rank_batches)
        logging.info(
            f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {end_idx-start_idx}, batch_num_after_step: {len(rank_batches)}"
        )
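The sampler hunk takes each rank's contiguous share of the batch list. When resuming, it skips the first `start_step` batches of that share and records how many remain in `batch_num`. The sketch below reproduces that slicing and assumes an even split, `batches_per_rank = len(batches) // world_size`. How the real sampler computes `batches_per_rank` is outside this hunk. `rank_slice` is a hypothetical helper.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def rank_slice(batches: Sequence[T], rank: int, world_size: int, start_step: int = 0) -> List[T]:
    """Return the batches this rank still has to process after resuming at start_step."""
    batches_per_rank = len(batches) // world_size  # assumption: even split, remainder dropped
    start_idx = rank * batches_per_rank
    end_idx = start_idx + batches_per_rank
    # Skip batches consumed before the checkpoint; the slice never crosses into the next rank.
    return list(batches[start_idx + start_step:end_idx])


print(rank_slice(list(range(10)), rank=1, world_size=2, start_step=2))  # [7, 8, 9]
```

The hunk logs both `end_idx - start_idx` and the sliced length. Having both numbers makes it visible when a resumed run has fewer steps left than a fresh epoch.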
funasr/download/name_maps_from_hub.py
@@ -12,10 +12,30 @@
    "Whisper-large-v2": "iic/speech_whisper-large_asr_multilingual",
    "Whisper-large-v3": "iic/Whisper-large-v3",
    "Qwen-Audio": "Qwen/Qwen-Audio",
    "emotion2vec_plus_large": "iic/emotion2vec_plus_large",
    "emotion2vec_plus_base": "iic/emotion2vec_plus_base",
    "emotion2vec_plus_seed": "iic/emotion2vec_plus_seed",
}
name_maps_hf = {
    "": "",
    "paraformer": "funasr/paraformer-zh",
    "paraformer-zh": "funasr/paraformer-zh",
    "paraformer-en": "funasr/paraformer-zh",
    "paraformer-zh-streaming": "funasr/paraformer-zh-streaming",
    "fsmn-vad": "funasr/fsmn-vad",
    "ct-punc": "funasr/ct-punc",
    "ct-punc-c": "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    "fa-zh": "funasr/fa-zh",
    "cam++": "funasr/campplus",
    "Whisper-large-v2": "iic/speech_whisper-large_asr_multilingual",
    "Whisper-large-v3": "iic/Whisper-large-v3",
    "Qwen-Audio": "Qwen/Qwen-Audio",
    "emotion2vec_plus_large": "emotion2vec/emotion2vec_plus_large",
    "iic/emotion2vec_plus_large": "emotion2vec/emotion2vec_plus_large",
    "emotion2vec_plus_base": "emotion2vec/emotion2vec_plus_base",
    "iic/emotion2vec_plus_base": "emotion2vec/emotion2vec_plus_base",
    "emotion2vec_plus_seed": "emotion2vec/emotion2vec_plus_seed",
    "iic/emotion2vec_plus_seed": "emotion2vec/emotion2vec_plus_seed",
}
name_maps_openai = {
funasr/models/llm_asr/adaptor.py
@@ -1,5 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables
@@ -63,3 +65,64 @@
        query_proj = self.norm(self.linear(query_output.last_hidden_state))
        return query_proj
@tables.register("adaptor_classes", "Transformer")
class Transformer(nn.Module):
    def __init__(
        self, downsample_rate=2, encoder_dim=1280, llm_dim=4096, ffn_dim: int = 2048, **kwargs
    ):
        super().__init__()
        self.k = downsample_rate
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(ffn_dim, self.llm_dim)
        from funasr.models.transformer.encoder import EncoderLayer
        from funasr.models.transformer.attention import MultiHeadedAttention
        from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
        self.blocks = nn.ModuleList(
            [
                EncoderLayer(
                    llm_dim,
                    MultiHeadedAttention(
                        kwargs.get("attention_heads", 8),
                        llm_dim,
                        kwargs.get("attention_dropout_rate", 0.0),
                    ),
                    PositionwiseFeedForward(
                        llm_dim,
                        llm_dim // 4,
                        kwargs.get("dropout_rate", 0.0),
                    ),
                    kwargs.get("dropout_rate", 0.0),
                )
                for i in range(kwargs.get("n_layer", 2))
            ]
        )
    def forward(self, x, ilens=None):
        batch_size, seq_len, dim = x.size()
        # num_frames_to_discard = seq_len % self.k
        chunk_num = (seq_len - 1) // self.k + 1
        pad_num = chunk_num * self.k - seq_len
        x = F.pad(x, (0, 0, 0, pad_num, 0, 0), value=0.0)
        # if num_frames_to_discard > 0:
        #     x = x[:, :-num_frames_to_discard, :]
        seq_len = x.size(1)
        x = x.contiguous()
        x = x.view(batch_size, chunk_num, dim * self.k)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        olens = None
        olens = (ilens - 1) // self.k + 1
        masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
        for layer, block in enumerate(self.blocks):
            x, masks = block(x, masks)
        return x, olens
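The new `Transformer` adaptor's `forward` first zero-pads the encoder output along time to a multiple of `downsample_rate` (k). It then reshapes so that each group of k consecutive frames becomes one vector of size `encoder_dim * k`, which feeds `linear1`/`linear2` and the encoder blocks. Output lengths become `(ilens - 1) // k + 1`, that is, ceil(ilens / k). The sketch below covers just the pad-and-stack step in pure Python, using lists instead of tensors so it runs without PyTorch. `stack_frames` and `output_lengths` are illustrative names, not functions in the codebase.

```python
from typing import List


def stack_frames(frames: List[List[float]], k: int) -> List[List[float]]:
    """Zero-pad to a multiple of k frames, then concatenate each group of k frames."""
    if not frames:
        return []
    seq_len, dim = len(frames), len(frames[0])
    chunk_num = (seq_len - 1) // k + 1          # ceil(seq_len / k)
    pad_num = chunk_num * k - seq_len           # zero frames appended at the end
    padded = frames + [[0.0] * dim for _ in range(pad_num)]
    # Same layout as x.view(batch, chunk_num, dim * k) for one contiguous utterance.
    return [sum(padded[i * k:(i + 1) * k], []) for i in range(chunk_num)]


def output_lengths(ilens: List[int], k: int) -> List[int]:
    """Per-utterance lengths after downsampling: ceil(ilen / k)."""
    return [(n - 1) // k + 1 for n in ilens]


print(stack_frames([[1.0], [2.0], [3.0]], k=2))  # [[1.0, 2.0], [3.0, 0.0]]
```

Padding rather than discarding the remainder keeps the last partial group of frames. The commented-out `num_frames_to_discard` lines in the hunk show the discarding alternative.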
funasr/models/seaco_paraformer/export_meta.py
@@ -163,7 +163,11 @@
    dha_ids = dha_pred.max(-1)[-1]
    dha_mask = (dha_ids == self.NOBIAS).int().unsqueeze(-1)
    decoder_out = decoder_out * dha_mask + dha_pred * (1 - dha_mask)
    return decoder_out, pre_token_length, alphas
    # get predicted timestamps
    us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
    return decoder_out, pre_token_length, us_alphas, us_cif_peak
def export_backbone_dummy_inputs(self):
@@ -178,7 +182,7 @@
def export_backbone_output_names(self):
    return ["logits", "token_num", "alphas"]
    return ["logits", "token_num", "us_alphas", "us_cif_peak"]
def export_backbone_dynamic_axes(self):
@@ -190,6 +194,8 @@
        "bias_embed": {0: "batch_size", 1: "num_hotwords"},
        "logits": {0: "batch_size", 1: "logits_length"},
        "pre_acoustic_embeds": {1: "feats_length1"},
        "us_alphas": {0: "batch_size", 1: "alphas_length"},
        "us_cif_peak": {0: "batch_size", 1: "alphas_length"},
    }
funasr/models/sense_voice/decoder.py
@@ -360,6 +360,7 @@
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
        logp = torch.log_softmax(logp, dim=-1)
        return logp.squeeze(0)[-1, :], state
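The `decoder.py` hunk adds `torch.log_softmax` inside `score()`. With it, the decoder returns normalized log-probabilities, as the variable name `logp` suggests, rather than raw logits. ESPnet-style beam search generally sums per-token scores across decoding steps, and that sum is only meaningful once each step's scores are normalized. The sketch below is a dependency-free, numerically stable log-softmax over one logit vector. `log_softmax` here is a plain-Python illustration, not the PyTorch function.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """log(softmax(x)) computed as x - logsumexp(x), subtracting the max for stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


logp = log_softmax([2.0, 1.0, 0.1])
print(sum(math.exp(v) for v in logp))  # approximately 1.0
```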
funasr/models/sense_voice/model.py
@@ -1264,15 +1264,29 @@
        if isinstance(task, str):
            task = [task]
        task = "".join([f"<|{x}|>" for x in task])
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        sos = kwargs.get("model_conf").get("sos")
        if isinstance(sos, str):
            initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language
            language = DecodingOptions.get("language", None)
            language = None if language == "auto" else language
        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
        sos_int = tokenizer.encode(sos, allowed_special="all")
            sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
            sos_int = tokenizer.encode(sos, allowed_special="all")
        else:
            language = DecodingOptions.get("language", None)
            language = None if language == "auto" else language
            initial_prompt = kwargs.get("initial_prompt", f"{task}")
            initial_prompt_lid = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
            initial_prompt_lid_int = tokenizer.encode(initial_prompt_lid, allowed_special="all")
            sos_int = [sos] + initial_prompt_lid_int
        eos = kwargs.get("model_conf").get("eos")
        eos_int = tokenizer.encode(eos, allowed_special="all")
        if isinstance(eos, str):
            eos_int = tokenizer.encode(eos, allowed_special="all")
        else:
            eos_int = [eos]
        self.beam_search.sos = sos_int
        self.beam_search.eos = eos_int[0]
@@ -1298,7 +1312,7 @@
        self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
        encoder_out, encoder_out_lens = self.encode(
            speech[None, :, :].permute(0, 2, 1), speech_lengths
            speech[None, :, :], speech_lengths
        )
        if text_token_int is not None:
funasr/models/sense_voice/whisper_lib/model.py
@@ -27,9 +27,24 @@
    n_text_layer: int
# class LayerNorm(nn.LayerNorm):
#     def forward(self, x: Tensor) -> Tensor:
#         return super().forward(x.float()).type(x.dtype)
class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)
class Linear(nn.Linear):
funasr/models/transformer/encoder.py
@@ -64,7 +64,7 @@
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
funasr/train_utils/trainer_ds.py
@@ -621,7 +621,6 @@
            self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
    def forward_step(self, model, batch, loss_dict={}):
        dtype = torch.bfloat16
        with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
            retval = model(**batch)
runtime/funasr_api/README.md
New file
@@ -0,0 +1,72 @@
# Python funasr_api
A Python API for the FunASR engine. Only 2pass servers are supported.
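Per `FunasrCore` in funasr_core.py, each 2pass session starts with one JSON text frame, continues with binary PCM frames, and ends with a JSON frame carrying `"is_speaking": false`. The sketch below rebuilds the two text frames. The helper names are hypothetical, and the field values are copied from `FunasrCore.new_connection`, so read them as what this client sends rather than as a complete server specification.

```python
import json

def build_2pass_config(wav_name="funasr_api"):
    """Session-open frame; values mirror FunasrCore.new_connection."""
    return json.dumps(
        {
            "mode": "2pass",
            "chunk_size": [0, 10, 5],
            "encoder_chunk_look_back": 4,
            "decoder_chunk_look_back": 1,
            "chunk_interval": 10,
            "wav_name": wav_name,
            "is_speaking": True,
        }
    )

def build_end_of_speech():
    """Frame FunasrCore.get_result sends after the last audio chunk."""
    return json.dumps({"is_speaking": False})

if __name__ == "__main__":
    print(build_2pass_config())
    print(build_end_of_speech())
```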
## Installation
### Install websocket-client and ffmpeg
```shell
pip install websocket-client
apt install ffmpeg -y
```
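Before running the examples, it can help to confirm both prerequisites are present. The helper below is a hypothetical sketch, not part of funasr_api. It uses `shutil.which` to find ffmpeg on PATH (FunasrTools.check_ffmpeg instead runs `ffmpeg -version`), and relies on websocket-client being importable as the `websocket` package, which is what funasr_core.py imports.

```python
import importlib.util
import shutil

def check_dependencies():
    """Return {dependency: available} for the two prerequisites."""
    return {
        # websocket-client is imported as the "websocket" package
        "websocket-client": importlib.util.find_spec("websocket") is not None,
        # FunasrTools shells out to the ffmpeg binary, so it must be on PATH
        "ffmpeg": shutil.which("ffmpeg") is not None,
    }

if __name__ == "__main__":
    for name, ok in check_dependencies().items():
        print(f"{name}: {'found' if ok else 'missing'}")
```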
#### Recognizer examples
Any audio format that ffmpeg can decode is supported; see FunASR/runtime/funasr_api/example.py for details.
```python
from funasr_api import FunasrApi

# create a recognizer
rcg = FunasrApi(
    uri="wss://www.funasr.com:10096/"
)
# recognize by file path
text = rcg.rec_file("asr_example.mp3")
print("recognizer by filepath result=", text)
# recognize from a buffer
# rec_buf(audio_buf, ffmpeg_decode=False): set ffmpeg_decode=True if the audio is not PCM or WAV
with open("asr_example.wav", "rb") as f:
    audio_bytes = f.read()
text = rcg.rec_buf(audio_bytes)
print("recognizer by buffer result=", text)
```
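`rec_file` decides whether to route audio through ffmpeg by its file extension. Note that `os.path.splitext` keeps the leading dot, so the comparison has to be against `".WAV"`/`".PCM"` rather than `"WAV"`/`"PCM"`; otherwise every file, including WAV, gets re-encoded. A minimal sketch of the check (`needs_ffmpeg` is a hypothetical name):

```python
import os

# extensions sent to the server without ffmpeg decoding
RAW_EXTS = (".PCM", ".WAV")

def needs_ffmpeg(path):
    """True if the file should be converted with FunasrTools.audio2wav first."""
    # splitext("a.wav") -> ("a", ".wav"); the dot is part of the extension
    ext = os.path.splitext(path)[-1].upper()
    return ext not in RAW_EXTS

if __name__ == "__main__":
    for p in ("asr_example.mp3", "asr_example.wav"):
        print(p, needs_ffmpeg(p))
```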
#### Streaming recognizer examples (use FunasrTools.audio2wav to convert audio to WAV if needed)
```python
from funasr_api import FunasrApi

rcg = FunasrApi(
    uri="wss://www.funasr.com:10096/"
)
# define a callback for messages from the server
def on_msg(msg):
    print("stream msg=", msg)
stream = rcg.create_stream(msg_callback=on_msg)
wav_path = "asr_example.wav"
with open(wav_path, "rb") as f:
    audio_bytes = f.read()
# use FunasrTools.audio2wav to convert other formats to WAV if needed
# import os
# from funasr_tools import FunasrTools
# file_ext = os.path.splitext(wav_path)[-1].upper()  # keeps the dot, e.g. ".WAV"
# if file_ext not in (".PCM", ".WAV"):
#     audio_bytes = FunasrTools.audio2wav(audio_bytes)
stride = int(60 * 10 / 10 / 1000 * 16000 * 2)
chunk_num = (len(audio_bytes) - 1) // stride + 1
for i in range(chunk_num):
    beg = i * stride
    data = audio_bytes[beg : beg + stride]
    stream.feed_chunk(data)
final_result = stream.wait_for_end()
print("asr_example.wav stream_result=", final_result)
```
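The stride expression in the streaming example works out to 60 ms of 16 kHz, 16-bit mono PCM: 60 / 1000 s × 16000 samples/s × 2 bytes = 1920 bytes per chunk. The `60 * 10 / 10` factor appears to correspond to `chunk_size[1] = 10` and `chunk_interval = 10` from the session config, though that reading is an assumption. The sketch below uses integer arithmetic and a generator instead of a precomputed `chunk_num`; the function names are hypothetical.

```python
def chunk_stride(chunk_ms=60, sample_rate=16000, sample_width=2):
    """Bytes per chunk of mono PCM: milliseconds -> samples -> bytes."""
    return chunk_ms * sample_rate * sample_width // 1000

def iter_chunks(audio_bytes, stride):
    """Yield consecutive slices of at most `stride` bytes; the last may be shorter."""
    for beg in range(0, len(audio_bytes), stride):
        yield audio_bytes[beg : beg + stride]

if __name__ == "__main__":
    stride = chunk_stride()
    sizes = [len(c) for c in iter_chunks(b"\x00" * 5000, stride)]
    print(stride, sizes)  # 1920 [1920, 1920, 1160]
```

Each yielded slice can be passed to `stream.feed_chunk`, exactly as in the loop above.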
## Acknowledgements
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/fix_bug_for_python_websocket) for contributing the websocket service.
3. We acknowledge [cgisky1980](https://github.com/cgisky1980/FunASR) for contributing the websocket service of offline model.
runtime/funasr_api/asr_example.mp3
Binary files differ
runtime/funasr_api/asr_example.wav
Binary files differ
runtime/funasr_api/example.py
New file
@@ -0,0 +1,70 @@
"""
  Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
  Reserved. MIT License  (https://opensource.org/licenses/MIT)
  2023-2024 by zhaomingwork@qq.com
"""
from funasr_api import FunasrApi
import wave
def recognizer_example():
    # create a recognizer
    rcg = FunasrApi(
        uri="wss://www.funasr.com:10096/"
    )
    # recognize by file path
    text=rcg.rec_file("asr_example.mp3")
    print("recognizer by filepath result=",text)
    # recognize from a buffer
    # rec_buf(audio_buf, ffmpeg_decode=False): set ffmpeg_decode=True if the audio is not PCM or WAV
    with open("asr_example.wav", "rb") as f:
        audio_bytes = f.read()
    text=rcg.rec_buf(audio_bytes)
    print("recognizer by buffer result=",text)
def recognizer_stream_example():
    rcg = FunasrApi(
        uri="wss://www.funasr.com:10096/"
    )
    # define a callback for messages from the server
    def on_msg(msg):
       print("stream msg=",msg)
    stream=rcg.create_stream(msg_callback=on_msg)
    wav_path = "asr_example.wav"
    with open(wav_path, "rb") as f:
        audio_bytes = f.read()
    # use FunasrTools.audio2wav to convert other formats to WAV if needed
    #import os
    #from funasr_tools import FunasrTools
    #file_ext=os.path.splitext(wav_path)[-1].upper()  # keeps the dot, e.g. ".WAV"
    #if file_ext not in (".PCM", ".WAV"):
    #       audio_bytes=FunasrTools.audio2wav(audio_bytes)
    stride = int(60 * 10 / 10 / 1000 * 16000 * 2)
    chunk_num = (len(audio_bytes) - 1) // stride + 1
    for i in range(chunk_num):
        beg = i * stride
        data = audio_bytes[beg : beg + stride]
        stream.feed_chunk(data)
    final_result=stream.wait_for_end()
    print("asr_example.wav stream_result=",final_result)
if __name__ == "__main__":
    print("example for Funasr_websocket_recognizer")
    recognizer_stream_example()
    recognizer_example()
runtime/funasr_api/funasr_api.py
New file
@@ -0,0 +1,96 @@
"""
  Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
  Reserved. MIT License  (https://opensource.org/licenses/MIT)
  2023-2024 by zhaomingwork@qq.com
"""
# pip install websocket-client
# apt install ffmpeg
import threading
import traceback
import json
import time
import numpy as np
from funasr_stream import FunasrStream
from funasr_tools import FunasrTools
from funasr_core import FunasrCore
# class for recognizer in websocket
class FunasrApi:
    """
    python asr recognizer lib
    """
    def __init__(
        self,
        uri="wss://www.funasr.com:10096/",
        timeout=1000,
        msg_callback=None,
    ):
        """
        uri: ws or wss server uri
        msg_callback: for message received
        timeout: timeout for get result
        """
        try:
            self.uri=uri
            self.timeout=timeout
            self.msg_callback=msg_callback
            self.funasr_core=None
        except Exception as e:
            print("Exception:", e)
            traceback.print_exc()
    def create_stream(self,msg_callback=None):
        if self.funasr_core is not None:
            self.funasr_core.close()
        funasr_core=self.new_core(msg_callback=msg_callback)
        return FunasrStream(funasr_core)
    def new_core(self, msg_callback=None):
        try:
            if self.funasr_core is not None:
                self.funasr_core.close()
            if msg_callback is None:
                msg_callback = self.msg_callback
            funasr_core = FunasrCore(self.uri, msg_callback=msg_callback, timeout=self.timeout)
            funasr_core.new_connection()
            self.funasr_core = funasr_core
            return funasr_core
        except Exception as e:
            print("new_core", e)
            exit(0)
    # recognize a buffer; set ffmpeg_decode=True if the audio is not PCM or WAV
    def rec_buf(self, audio_buf, ffmpeg_decode=False):
        try:
            funasr_core = self.new_core()
            # FunasrCore.rec_buf sends end-of-speech, waits, and returns the text
            return funasr_core.rec_buf(audio_buf, ffmpeg_decode=ffmpeg_decode)
        except Exception as e:
            print("rec_buf", e)
            return
    # recognize a file
    def rec_file(self, file_path):
        try:
            funasr_core = self.new_core()
            # FunasrCore.rec_file decodes with ffmpeg if needed and returns the text
            return funasr_core.rec_file(file_path)
        except Exception as e:
            print("rec_file", e)
            return
runtime/funasr_api/funasr_core.py
New file
@@ -0,0 +1,230 @@
"""
  Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
  Reserved. MIT License  (https://opensource.org/licenses/MIT)
  2023-2024 by zhaomingwork@qq.com
"""
# pip install websocket-client
# apt install ffmpeg
import ssl
from websocket import ABNF
from websocket import create_connection
from queue import Queue
import threading
import traceback
import json
import time
import numpy as np
from funasr_tools import FunasrTools
# class for recognizer in websocket
class FunasrCore:
    """
    python asr recognizer lib
    """
    def __init__(
        self,
        uri="wss://www.funasr.com:10096/",
        msg_callback=None,
        timeout=1000,
    ):
        """
        uri: ws or wss server uri
        msg_callback: for message received
        timeout: timeout for get result
        """
        try:
            # str.find returns 0 (falsy) on a match at the start, so use startswith
            if uri.startswith("wss://"):
                is_ssl = True
            elif uri.startswith("ws://"):
                is_ssl = False
            else:
                print("unsupported uri", uri)
                exit(0)
            if is_ssl:
                # certificate and hostname verification are disabled
                ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
                ssl_opt = {"cert_reqs": ssl.CERT_NONE}
            else:
                ssl_context = None
                ssl_opt = None
            self.ssl_opt=ssl_opt
            self.ssl_context=ssl_context
            self.uri = uri
            print("connect to url", uri)
            self.msg_callback=msg_callback
            self.is_final=False
            self.rec_text=""
            self.timeout=timeout
            self.rec_file_len=0
            self.connect_state=0
        except Exception as e:
            print("Exception:", e)
            traceback.print_exc()
    def new_connection(self):
        try:
            self.websocket = create_connection(self.uri, ssl=self.ssl_context, sslopt=self.ssl_opt)
            self.is_final = False
            self.rec_text = ""
            self.rec_file_len = 0
            self.connect_state = 0
            message = json.dumps(
                {
                    "mode": "2pass",
                    "chunk_size": [int(x) for x in "0,10,5".split(",")],
                    "encoder_chunk_look_back": 4,
                    "decoder_chunk_look_back": 1,
                    "chunk_interval": 10,
                    "wav_name": "funasr_api",
                    "is_speaking": True,
                }
            )
            self.websocket.send(message)
            self.connect_state = 1
            # thread for receiving messages
            self.thread_msg = threading.Thread(
                target=FunasrCore.thread_rec_msg, args=(self,)
            )
            self.thread_msg.start()
            print("new_connection: ", message)
        except Exception as e:
            print("new_connection", e)
    # threads for rev msg
    def thread_rec_msg(self):
        try:
            while True:
                if  self.connect_state==0:
                    time.sleep(0.1)
                    continue
                if self.connect_state==2:
                    break
                msg = self.websocket.recv()
                if msg is None or len(msg) == 0:
                    continue
                msg = json.loads(msg)
                if msg['is_final']==True:
                    self.is_final=True
                if msg['mode']=='2pass-offline':
                   self.rec_text=self.rec_text+msg['text']
                if not self.msg_callback is None:
                   self.msg_callback(msg)
        except Exception as e:
            #print("client closed")
            return
    # feed data to asr engine in stream way
    def feed_chunk(self, chunk):
        try:
            self.websocket.send(chunk, ABNF.OPCODE_BINARY)
            return
        except Exception:
            print("feed chunk error")
            return
    def close(self):
        # signal the receive thread to stop, then close the socket
        self.connect_state = 2
        self.websocket.close()
    def rec_buf(self, audio_bytes, ffmpeg_decode=False):
        try:
            if ffmpeg_decode:
                audio_bytes = FunasrTools.audio2wav(audio_bytes)
            self.rec_file_len = len(audio_bytes)
            # 60 ms of 16 kHz, 16-bit mono PCM per chunk (1920 bytes)
            stride = int(60 * 10 / 10 / 1000 * 16000 * 2)
            chunk_num = (len(audio_bytes) - 1) // stride + 1
            for i in range(chunk_num):
                beg = i * stride
                data = audio_bytes[beg : beg + stride]
                self.feed_chunk(data)
            return self.get_result()
        except Exception as e:
            print("rec_buf", e)
            return
    # recognize a file
    def rec_file(self, file_path):
        try:
            import os
            # splitext keeps the leading dot, e.g. ".WAV"
            file_ext = os.path.splitext(file_path)[-1].upper()
            with open(file_path, "rb") as f:
                audio_bytes = f.read()
            if file_ext not in (".PCM", ".WAV"):
                audio_bytes = FunasrTools.audio2wav(audio_bytes)
            if audio_bytes is None:
                print("error, ffmpeg cannot decode this file!")
                exit(0)
            return self.rec_buf(audio_bytes)
        except Exception as e:
            print("rec_file", e)
            return
    def wait_for_result(self):
        try:
            # timeout is counted in 10 ms polling ticks
            timeout = self.timeout
            # audio duration in 10 ms ticks: bytes / 16000 Hz / 2 bytes * 100
            file_dur = self.rec_file_len / 16000 / 2 * 100
            if file_dur > timeout:
                timeout = file_dur
                self.timeout = timeout
            # file_dur == 0 means streaming mode: wait without a timeout
            while not self.is_final and (timeout > 0 or file_dur == 0):
                time.sleep(0.01)
                timeout = timeout - 1
            if timeout <= 0 and not file_dur == 0:
                print("time out!", self.timeout)
        except Exception as e:
            print("wait_for_result", e)
            return
    def get_result(self):
        try:
            # tell the server that no more audio is coming
            message = json.dumps({"is_speaking": False})
            self.websocket.send(message)
            self.wait_for_result()
            self.close()
            # return the recognized text
            return self.rec_text
        except Exception as e:
            # print("get_result ", e)
            return self.rec_text
runtime/funasr_api/funasr_stream.py
New file
@@ -0,0 +1,72 @@
"""
  Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
  Reserved. MIT License  (https://opensource.org/licenses/MIT)
  2023-2024 by zhaomingwork@qq.com
"""
# pip install websocket-client
# apt install ffmpeg
import threading
import traceback
import json
import time
# class for recognizer in websocket
class FunasrStream:
    """
    python asr recognizer lib
    """
    def __init__(
        self,
        funasr_core
    ):
        """
        funasr_core: a connected FunasrCore instance that carries this stream
        """
        try:
            self.funasr_core=funasr_core
        except Exception as e:
            print("FunasrStream init Exception:", e)
            traceback.print_exc()
    # feed data to asr engine in stream way
    def feed_chunk(self, chunk):
        try:
            if self.funasr_core is None:
                print("error in stream, funasr_core is None")
                exit(0)
            self.funasr_core.feed_chunk(chunk)
            return
        except Exception:
            print("feed chunk error")
            return
    # send end-of-speech and return the full result for this stream
    def wait_for_end(self):
        try:
            message = json.dumps({"is_speaking": False})
            self.funasr_core.websocket.send(message)
            self.funasr_core.wait_for_result()
            self.funasr_core.close()
            # return the recognized text
            return self.funasr_core.rec_text
        except Exception as e:
            print("error in wait_for_end", e)
            return ""
runtime/funasr_api/funasr_tools.py
New file
@@ -0,0 +1,84 @@
"""
  Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
  Reserved. MIT License  (https://opensource.org/licenses/MIT)
  2023-2024 by zhaomingwork@qq.com
"""
# pip install websocket-client
# apt install ffmpeg
import threading
import traceback
import time
# helper utilities: ffmpeg detection and audio conversion
class FunasrTools:
    """
    ffmpeg helpers for the FunASR python api
    """
    def __init__(
        self
    ):
        """
        """
        try:
            if not FunasrTools.check_ffmpeg():
                print("please install ffmpeg first; on Ubuntu, run: apt install -y ffmpeg")
                exit(0)
        except Exception as e:
            print("Exception:", e)
            traceback.print_exc()
    # check ffmpeg installed
    @staticmethod
    def check_ffmpeg():
        import subprocess
        try:
            subprocess.run(['ffmpeg', '-version'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            return True
        except FileNotFoundError:
            return False
    # use ffmpeg to convert audio to 16 kHz mono wav
    @staticmethod
    def audio2wav(audiobuf):
        try:
            import subprocess
            if not FunasrTools.check_ffmpeg():
                print("please install ffmpeg first; on Ubuntu, run: apt install -y ffmpeg")
                exit(0)
            # read from stdin, downmix to mono, resample to 16 kHz, write wav to stdout
            ffmpeg_target_to_outwav = ["ffmpeg", "-i", "-", "-ac", "1", "-ar", "16000", "-f", "wav", "pipe:1"]
            pipe_to = subprocess.Popen(
                ffmpeg_target_to_outwav,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            wavbuf, err = pipe_to.communicate(audiobuf)
            if str(err).find("Error") >= 0 or str(err).find("Unknown") >= 0 or str(err).find("Invalid") >= 0:
                print("ffmpeg err", err)
                return None
            return wavbuf
        except Exception as e:
            print("audio2wav", e)
            return None
runtime/http/CMakeLists.txt
New file
@@ -0,0 +1,142 @@
cmake_minimum_required(VERSION 3.16)
project(FunASRWebscoket)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
option(ENABLE_HTTP "Whether to build http server" ON)
option(ENABLE_PORTAUDIO "Whether to build portaudio" ON)
if(WIN32)
  file(REMOVE ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/config.h
    ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/export.h
    ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/logging.h
    ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/raw_logging.h
    ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/stl_logging.h
    ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/vlog_is_on.h)
else()
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -fPIC")
endif()
option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
if(ENABLE_HTTP)
  # cmake_policy(SET CMP0135 NEW)
  include(FetchContent)
  if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/asio/asio )
    FetchContent_Declare(asio
      URL   https://github.com/chriskohlhoff/asio/archive/refs/tags/asio-1-24-0.tar.gz
    SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/asio
    )
    FetchContent_MakeAvailable(asio)
  endif()
  include_directories(${PROJECT_SOURCE_DIR}/third_party/asio/asio/include)
  if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/json/ChangeLog.md )
    FetchContent_Declare(json
      URL   https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz
    SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/json
    )
    FetchContent_MakeAvailable(json)
  endif()
  include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)
endif()
if(ENABLE_PORTAUDIO)
  include(FetchContent)
  set(portaudio_URL  "http://files.portaudio.com/archives/pa_stable_v190700_20210406.tgz")
  set(portaudio_URL2 "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/pa_stable_v190700_20210406.tgz")
  set(portaudio_HASH "SHA256=47efbf42c77c19a05d22e627d42873e991ec0c1357219c0d74ce6a2948cb2def")
  FetchContent_Declare(portaudio
    URL
      ${portaudio_URL}
      ${portaudio_URL2}
    URL_HASH          ${portaudio_HASH}
  )
  FetchContent_GetProperties(portaudio)
  if(NOT portaudio_POPULATED)
    message(STATUS "Downloading portaudio from ${portaudio_URL}")
    FetchContent_Populate(portaudio)
  endif()
  message(STATUS "portaudio is downloaded to ${portaudio_SOURCE_DIR}")
  message(STATUS "portaudio's binary dir is ${portaudio_BINARY_DIR}")
  add_subdirectory(${portaudio_SOURCE_DIR} ${portaudio_BINARY_DIR} EXCLUDE_FROM_ALL)
  if(NOT WIN32)
    target_compile_options(portaudio PRIVATE "-Wno-deprecated-declarations")
  else()
    install(TARGETS portaudio DESTINATION ..)
  endif()
endif()
# library directories for onnxruntime and ffmpeg
link_directories(${ONNXRUNTIME_DIR}/lib)
link_directories(${FFMPEG_DIR}/lib)
if(ENABLE_GLOG)
    include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src)
    set(BUILD_TESTING OFF)
    add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
    include_directories(${glog_BINARY_DIR})
endif()
if(ENABLE_FST)
    # fst depend on glog and gflags
    include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/gflags)
    add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/gflags gflags)
    include_directories(${gflags_BINARY_DIR}/include)
    # the following openfst is cloned from https://github.com/kkm000/openfst.git
    # with some patches to fix build errors.
    add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/openfst openfst)
    include_directories(${openfst_SOURCE_DIR}/src/include)
    if(WIN32)
    include_directories(${openfst_SOURCE_DIR}/src/lib)
    endif()
endif()
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/include/)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/src)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp/include/)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/jieba/include)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/jieba/include/limonp/include)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp yaml-cpp)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank/kaldi-native-fbank/csrc csrc)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/src src)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi kaldi)
# install OpenSSL first, e.g. apt-get install libssl-dev
find_package(OpenSSL REQUIRED)
message("CXX_FLAGS "${CMAKE_CXX_FLAGS})
# Get the paths of all include directories in the project
get_property(includes DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES)
# Iterate over and print each include directory path
foreach(include ${includes})
  message("Include directory: ${include}")
endforeach()
add_subdirectory(bin)
runtime/http/bin/CMakeLists.txt
New file
@@ -0,0 +1,23 @@
if(WIN32)
  include_directories(${ONNXRUNTIME_DIR}/include)
  include_directories(${FFMPEG_DIR}/include)
  include_directories(${OPENSSL_ROOT_DIR}/include)
  link_directories(${OPENSSL_ROOT_DIR}/lib)
  add_definitions(-D_WEBSOCKETPP_CPP11_RANDOM_DEVICE_)
  add_definitions(-D_WEBSOCKETPP_CPP11_TYPE_TRAITS_)
  add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/bigobj>")
  add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>")
  SET(RELATION_SOURCE "../../onnxruntime/src/resample.cpp" "../../onnxruntime/src/util.cpp" "../../onnxruntime/src/alignedmem.cpp" "../../onnxruntime/src/encode_converter.cpp")
endif()
find_package(ZLIB REQUIRED)
file(GLOB SRC_FILES "*.cpp")
add_executable(funasr-http-server ${SRC_FILES} ${RELATION_SOURCE})
target_link_libraries(funasr-http-server PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})
runtime/http/bin/asr_sessions.h
New file
@@ -0,0 +1,20 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
// FUNASR_MESSAGE defines the message passed between the funasr engine and the http server
#ifndef HTTP_SERVER2_SESSIONS_HPP
#define HTTP_SERVER2_SESSIONS_HPP
#include "funasrruntime.h"
#include "nlohmann/json.hpp"
#include <atomic>
#include <memory>
#include <vector>
typedef struct {
  nlohmann::json msg;
  std::shared_ptr<std::vector<char>> samples;
  std::shared_ptr<std::vector<std::vector<float>>> hotwords_embedding=nullptr;
  FUNASR_DEC_HANDLE decoder_handle=nullptr;
  std::atomic<int> status;
} FUNASR_MESSAGE;
#endif // HTTP_SERVER2_SESSIONS_HPP
runtime/http/bin/connection.cpp
New file
@@ -0,0 +1,196 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
//
// connection.cpp
// portions adapted from http://www.boost.org/
#include "connection.hpp"
#include <thread>
#include <utility>
namespace http {
namespace server2 {
//std::ofstream fwout("out.data", std::ios::binary);
std::shared_ptr<FUNASR_MESSAGE> &connection::get_data_msg() { return data_msg; }
connection::connection(asio::ip::tcp::socket socket,
                       asio::io_context &io_decoder, int connection_id,
                       std::shared_ptr<ModelDecoder> model_decoder)
    : socket_(std::move(socket)),
      io_decoder(io_decoder),
      connection_id(connection_id),
      model_decoder(model_decoder)
{
  s_timer = std::make_shared<asio::steady_timer>(io_decoder);
}
void connection::setup_timer() {
  if (data_msg->status == 1) return;
  s_timer->expires_after(std::chrono::seconds(3));
  s_timer->async_wait([=](const asio::error_code &ec) {
    if (!ec) {
      std::cout << "time is out!" << std::endl;
      if (data_msg->status == 1) return;
      data_msg->status = 1;
      s_timer->cancel();
      auto wf = std::bind(&connection::write_back, std::ref(*this), "");
      // close the connection
      strand_->post(wf);
    }
  });
}
void connection::start() {
  std::lock_guard<std::mutex> lock(m_lock);  // for thread safety
  try {
    data_msg = std::make_shared<FUNASR_MESSAGE>();  // put a new data vector for
                                                    // new connection
    data_msg->samples = std::make_shared<std::vector<char>>();
    //data_msg->samples->reserve(16000*20);
    data_msg->msg = nlohmann::json::parse("{}");
    data_msg->msg["wav_format"] = "pcm";
    data_msg->msg["wav_name"] = "wav-default-id";
    data_msg->msg["itn"] = true;
    data_msg->msg["audio_fs"] = 16000;  // default is 16k
    data_msg->msg["access_num"] = 0;    // number of accesses to this object;
                                        // it can be freed safely once this is 0
    data_msg->msg["is_eof"] = false;
    data_msg->status = 0;
    strand_ = std::make_shared<asio::io_context::strand>(io_decoder);
    FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(
        model_decoder->get_asr_handle(), ASR_OFFLINE, global_beam_, lattice_beam_, am_scale_);
    data_msg->decoder_handle = decoder_handle;
    if (data_msg->hotwords_embedding == nullptr) {
      std::unordered_map<std::string, int> merged_hws_map;
      std::string nn_hotwords = "";
      if (true) {
        std::string json_string = "{}";
        if (!json_string.empty()) {
          nlohmann::json json_fst_hws;
          try {
            json_fst_hws = nlohmann::json::parse(json_string);
            if (json_fst_hws.type() == nlohmann::json::value_t::object) {
              // fst
              try {
                std::unordered_map<std::string, int> client_hws_map =
                    json_fst_hws;
                merged_hws_map.insert(client_hws_map.begin(),
                                      client_hws_map.end());
              } catch (const std::exception &e) {
                std::cout << e.what();
              }
            }
          } catch (std::exception const &e) {
            std::cout << e.what();
            // nn
            std::string client_nn_hws = "{}";
            nn_hotwords += " " + client_nn_hws;
            std::cout << "nn hotwords: " << client_nn_hws;
          }
        }
      }
      merged_hws_map.insert(hws_map_.begin(), hws_map_.end());
      // fst
      std::cout << "hotwords: ";
      for (const auto &pair : merged_hws_map) {
        nn_hotwords += " " + pair.first;
        std::cout << pair.first << " : " << pair.second;
      }
      FunWfstDecoderLoadHwsRes(data_msg->decoder_handle, fst_inc_wts_,
                               merged_hws_map);
      // nn
      std::vector<std::vector<float>> new_hotwords_embedding =
          CompileHotwordEmbedding(model_decoder->get_asr_handle(), nn_hotwords);
      data_msg->hotwords_embedding =
          std::make_shared<std::vector<std::vector<float>>>(
              new_hotwords_embedding);
    }
    file_parse = std::make_shared<http::server2::file_parser>(data_msg);
    do_read();
  } catch (const std::exception &e) {
    std::cout << "error:" << e.what();
  }
}
void connection::write_back(std::string str) {
  s_timer->cancel();
  std::cout << "jsonresult=" << data_msg->msg["asr_result"].dump() << std::endl;
  reply_ = reply::stock_reply(
      data_msg->msg["asr_result"].dump());  // reply::stock_reply();
  do_write();
}
void connection::do_read() {
  // status==1 means time out
  if (data_msg->status == 1) return;
  s_timer->cancel();
  setup_timer();
  auto self(shared_from_this());
  socket_.async_read_some(
      asio::buffer(buffer_),
      [this, self](asio::error_code ec, std::size_t bytes_transferred) {
        if (!ec) {
          auto is = std::begin(buffer_);
          auto ie = std::next(is, bytes_transferred);
          http::server2::file_parser::result_type rtype =
              file_parse->parse_file(is, ie);
          if (rtype == http::server2::file_parser::result_type::ok) {
            // Bind `self` (a shared_ptr) instead of std::ref(*this) so the
            // connection stays alive until write_back has run on the strand.
            auto wf = std::bind(&connection::write_back, self, "aa");
            auto f = std::bind(&ModelDecoder::do_decoder,
                               std::ref(*model_decoder), std::ref(data_msg));
            // The strand runs posted handlers one at a time, in post order:
            // first the decode task, then the write-back/close task.
            strand_->post(f);
            strand_->post(wf);
          }
          do_read();
        }
      });
}
void connection::do_write() {
  auto self(shared_from_this());
  asio::async_write(socket_, reply_.to_buffers(),
                    [this, self](asio::error_code ec, std::size_t) {
                      if (!ec) {
                        // Initiate graceful connection closure.
                        asio::error_code ignored_ec;
                        socket_.shutdown(asio::ip::tcp::socket::shutdown_both,
                                         ignored_ec);
                      }
                      // No new asynchronous operations are started. This means
                      // that all shared_ptr references to the connection object
                      // will disappear and the object will be destroyed
                      // automatically after this handler returns. The
                      // connection class's destructor closes the socket.
                    });
}
}  // namespace server2
}  // namespace http
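In `connection::start()` above, the per-request hotwords in `client_hws_map` are inserted into `merged_hws_map` before the server-wide `hws_map_`. Since `std::unordered_map::insert` never overwrites a key that is already present, a weight supplied with the request should take precedence over the server default for the same phrase. The standalone sketch below isolates that ordering rule; `merge_hotwords` and `join_hotwords` are hypothetical helpers for illustration, not FunASR functions.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical helper mirroring the merge order in connection::start():
// request hotwords first, then the server-wide defaults. insert() skips keys
// that already exist, so the request's weight wins for a shared phrase.
std::unordered_map<std::string, int> merge_hotwords(
    const std::unordered_map<std::string, int> &client_hws,
    const std::unordered_map<std::string, int> &server_hws) {
  std::unordered_map<std::string, int> merged;
  merged.insert(client_hws.begin(), client_hws.end());
  merged.insert(server_hws.begin(), server_hws.end());  // existing keys kept
  return merged;
}

// Hypothetical helper: builds the space-separated phrase list, the same way
// nn_hotwords is extended with " " + phrase for every merged entry.
std::string join_hotwords(const std::unordered_map<std::string, int> &hws) {
  std::string out;
  for (const auto &pair : hws) out += " " + pair.first;
  return out;
}
```

Note that the phrase order in the joined string depends on `unordered_map` iteration order, which the standard leaves unspecified.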
runtime/http/bin/connection.hpp
New file
@@ -0,0 +1,104 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
//
// Portions of this code are adapted from http://www.boost.org/
//
#ifndef HTTP_SERVER2_CONNECTION_HPP
#define HTTP_SERVER2_CONNECTION_HPP
#include <array>
#include <asio.hpp>
#include <atomic>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include "reply.hpp"
#include <fstream>
#include "file_parse.hpp"
#include "model-decoder.h"
extern std::unordered_map<std::string, int> hws_map_;
extern int fst_inc_wts_;
extern float global_beam_, lattice_beam_, am_scale_;
namespace http {
namespace server2 {
/// Represents a single connection from a client.
class connection : public std::enable_shared_from_this<connection> {
 public:
  connection(const connection &) = delete;
  connection &operator=(const connection &) = delete;
  ~connection() { std::cout << "one connection is closed" << std::endl; }
  /// Construct a connection with the given socket.
  explicit connection(asio::ip::tcp::socket socket,
                      asio::io_context &io_decoder, int connection_id,
                      std::shared_ptr<ModelDecoder> model_decoder);
  /// Start the first asynchronous operation for the connection.
  void start();
  std::shared_ptr<FUNASR_MESSAGE> &get_data_msg();
  void write_back(std::string str);
 private:
  /// Perform an asynchronous read operation.
  void do_read();
  /// Perform an asynchronous write operation.
  void do_write();
  void do_decoder();
  void setup_timer();
  /// Socket for the connection.
  asio::ip::tcp::socket socket_;
  /// Buffer for incoming data.
  std::array<char, 8192> buffer_;
  /// for time out
  std::shared_ptr<asio::steady_timer> s_timer;
  std::shared_ptr<ModelDecoder> model_decoder;
  int connection_id = 0;
  /// The reply to be sent back to the client.
  reply reply_;
  asio::io_context &io_decoder;
  std::shared_ptr<FUNASR_MESSAGE> data_msg;
  std::mutex m_lock;
  std::shared_ptr<asio::io_context::strand> strand_;
  std::shared_ptr<http::server2::file_parser> file_parse;
};
typedef std::shared_ptr<connection> connection_ptr;
}  // namespace server2
}  // namespace http
#endif  // HTTP_SERVER2_CONNECTION_HPP
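`connection` derives from `std::enable_shared_from_this`, and each asynchronous handler in connection.cpp captures `self`, so the object stays alive while an operation is pending and is destroyed once `do_write()` finishes without starting new work. The sketch below reproduces that lifetime pattern with a toy FIFO loop standing in for `asio::io_context`; `ToyLoop` and `ToyConnection` are hypothetical names and nothing here is Asio code.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <queue>

// Toy stand-in for an io_context: a FIFO of pending completion handlers.
struct ToyLoop {
  std::queue<std::function<void()>> handlers;
  void post(std::function<void()> h) { handlers.push(std::move(h)); }
  void run() {
    while (!handlers.empty()) {
      auto h = std::move(handlers.front());
      handlers.pop();
      h();  // h, and any shared_ptr it captured, is destroyed after this call
    }
  }
};

// Same lifetime rule as http::server2::connection: every pending operation
// holds a shared_ptr to the object, so it lives as long as work is pending.
class ToyConnection : public std::enable_shared_from_this<ToyConnection> {
 public:
  ToyConnection(ToyLoop &loop, int &destroyed)
      : loop_(loop), destroyed_(destroyed) {}
  ~ToyConnection() { ++destroyed_; }
  void start(int reads) { do_read(reads); }

 private:
  void do_read(int remaining) {
    auto self(shared_from_this());
    loop_.post([this, self, remaining] {
      // Chain the next "read"; on the last one nothing new is posted, so the
      // final shared_ptr dies with this handler and the destructor runs.
      if (remaining > 1) do_read(remaining - 1);
    });
  }
  ToyLoop &loop_;
  int &destroyed_;
};
```

The temporary `shared_ptr` returned by `std::make_shared` can be dropped right after `start()`; the posted handler is what keeps the object alive.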
runtime/http/bin/file_parse.cpp
New file
@@ -0,0 +1,29 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
#include "file_parse.hpp"
namespace http {
namespace server2 {
file_parser::file_parser(std::shared_ptr<FUNASR_MESSAGE> data_msg)
    : data_msg(data_msg), now_state(start) {}
} // namespace server2
} // namespace http
runtime/http/bin/file_parse.hpp
New file
@@ -0,0 +1,234 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
// ~~~~~~~~~~~~~~~~~~
#ifndef HTTP_SERVER2_REQUEST_FILEPARSER_HPP
#define HTTP_SERVER2_REQUEST_FILEPARSER_HPP
#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <string>
#include <tuple>
#include "asr_sessions.h"
namespace http {
namespace server2 {
/// Parser for incoming requests.
class file_parser {
 public:
  /// Construct a parser that appends uploaded file bytes to data_msg->samples.
  explicit file_parser(std::shared_ptr<FUNASR_MESSAGE> data_msg);
  /// Parser state; parse_file() returns it, and `ok` means the file is complete.
  enum result_type { start, in_boundary, data, ok };
  template <typename InputIterator>
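  void parse_one_line(InputIterator &is, InputIterator &ie, InputIterator &it) {
    // Move `is` to the start of the next line and `it` to that line's '\n'
    // (or to `ie` if the buffer ends before a newline).
    if (is != it) {
      is = it;
    }
    if (is != ie && *is == '\n') {
      is = std::next(is);
    }
    it = std::find(is, ie, '\n');
  }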
  std::string trim_name(std::string raw_string) {
    // keep only the text between the first pair of double quotes
    std::size_t pos = raw_string.find('\"');
    if (pos != std::string::npos) {
      raw_string = raw_string.substr(pos + 1);
      pos = raw_string.find('\"');
      raw_string = raw_string.substr(0, pos);
    }
    return raw_string;
  }
  std::string parese_file_ext(std::string file_name) {
    // text after the last '.', or "" if there is none
    std::size_t pos = file_name.rfind('.');
    std::string ext = "";
    if (pos != std::string::npos) ext = file_name.substr(pos + 1);
    return ext;
  }
  template <typename InputIterator>
  int parse_data_content(InputIterator is, InputIterator ie, InputIterator it) {
    // Returns 1 once the closing "--boundary--" line has been seen and the
    // delimiter has been trimmed from data_msg->samples, otherwise 0.
    if (it == ie) {
      return 0;
    }
    int len = std::distance(it + 1, ie);
    if (len <= 0) {
      return 0;
    }
    // check if at the end: "--boundary--" needs +4 for the two "--"
    // (boundary still ends with the line break of its header line here)
    if (static_cast<std::size_t>(len) == boundary.length() + 4) {
      std::string str(it + 1, ie);
      // strip '\r' and '\n' from a local copy so that the stored boundary,
      // and thus the length check above, stays the same on later calls
      std::string end_boundary = boundary;
      while (!end_boundary.empty() &&
             (end_boundary.back() == '\n' || end_boundary.back() == '\r')) {
        end_boundary.pop_back();
      }
      auto found_boundary = str.find(end_boundary);
      if (found_boundary == std::string::npos) {
        std::cout << "not found end boundary!=" << found_boundary << std::endl;
        return 0;
      }
      // remove the end of data that contains '\n' or '\r'
      int last_sub = 0;
      if (*(it) == '\n') {
        last_sub++;
      }
      int lasts_len = std::distance(it, ie);
      data_msg->samples->erase(data_msg->samples->end() - last_sub - lasts_len,
                               data_msg->samples->end());
      std::cout << "one file finished, file size=" << data_msg->samples->size()
                << std::endl;
      return 1;
    }
    // not the closing boundary yet
    return 0;
  }
  template <typename InputIterator>
  void parse_boundary_content(InputIterator is, InputIterator ie,
                            InputIterator it) {
    parse_one_line(is, ie, it);
    std::string str;
    while (it != ie) {
      str = std::string(is, it);
      auto found_content = str.find("Content-Disposition:");
      auto found_filename = str.find("filename=");
      if (found_content != std::string::npos &&
          found_filename != std::string::npos) {
        std::string file_name =
            str.substr(found_filename + 9, std::string::npos);
        file_name = trim_name(file_name);
        std::string ext = parese_file_ext(file_name);
        if (file_name.find(".wav") != std::string::npos) {
          std::cout << "set wav_format=pcm, file_name=" << file_name
                    << std::endl;
          data_msg->msg["wav_format"] = "pcm";
        } else {
          std::cout << "set wav_format=" << ext << ", file_name=" << file_name
                    << std::endl;
          data_msg->msg["wav_format"] = ext;
        }
        data_msg->msg["wav_name"] = file_name;
        now_state = data;
      } else {
        auto found_content = str.find("Content-Disposition:");
        auto found_name = str.find("name=");
        if (found_content != std::string::npos &&
            found_name != std::string::npos) {
          std::string name = str.substr(found_name + 5, std::string::npos);
          name = trim_name(name);
          parse_one_line(is, ie, it);
          if (it != ie && *it == '\n') it++;
          parse_one_line(is, ie, it);
          str = std::string(is, it);
          std::cout << "para: name=" << name << ",value=" << str << std::endl;
        }
      }
      parse_one_line(is, ie, it);
      if (now_state == data && std::distance(is, it) <= 2) {
        break;
      }
    }
    if (now_state == data) {
      if (it != ie && *it == '\n') it++;
      // the rest of this buffer is file content
      data_msg->samples->insert(data_msg->samples->end(), it, ie);
    }
  }
  template <typename InputIterator>
  result_type parse_file(InputIterator is, InputIterator ie) {
    if (now_state == data) {
      data_msg->samples->insert(data_msg->samples->end(), is, ie);
    }
    auto it = is;
    while (it != ie) {
      parse_one_line(is, ie, it);
      if (now_state == data) {
        // for data end search
        int ret = parse_data_content(is, ie, it);
        if (ret == 0) continue;
        return ok;
      } else {
        // current line including its '\n' (the buffer may end without one)
        std::string str(is, it == ie ? ie : std::next(it));
        if (now_state == start) {
          auto found_boundary = str.find("Content-Length:");
          if (found_boundary != std::string::npos) {
            std::string file_len =
                str.substr(found_boundary + 15, std::string::npos);
            data_msg->samples->reserve(std::stoi(file_len));
          }
          found_boundary = str.find("boundary=");
          if (found_boundary != std::string::npos) {
            boundary = str.substr(found_boundary + 9, std::string::npos);
            now_state = in_boundary;
          }
        } else if (now_state == in_boundary) {
          // for file header
          auto found_boundary = str.find(boundary);
          if (found_boundary != std::string::npos) {
            parse_boundary_content(is, ie, it);
          }
        }
      }
    }
    return now_state;
  }
 private:
  std::shared_ptr<FUNASR_MESSAGE> data_msg;
  result_type now_state;
  std::string boundary = "";
};
}  // namespace server2
}  // namespace http
#endif  // HTTP_SERVER2_REQUEST_FILEPARSER_HPP
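`file_parser` consumes the multipart/form-data body incrementally, up to 8 KiB at a time (`connection::buffer_`), moving through the `start`, `in_boundary`, `data` and `ok` states. For comparison, here is a simplified, non-streaming sketch that works on a fully buffered body: it looks for the first part whose headers carry `filename="..."` and returns the bytes between the blank line ending the part headers and the CRLF before the next delimiter. `extract_file_part` and `wav_format_for` are hypothetical helpers; the sketch assumes CRLF line endings and a boundary token already taken from the `Content-Type` header.

```cpp
#include <cassert>
#include <string>

// Result of the sketch: the uploaded file name and its raw bytes.
struct FilePart {
  std::string file_name;
  std::string data;
};

// Returns true and fills `out` with the first part whose headers carry a
// quoted filename="..." attribute; false if none is found or the body is
// truncated. Assumes CRLF line endings throughout.
bool extract_file_part(const std::string &body, const std::string &boundary,
                       FilePart &out) {
  const std::string delim = "--" + boundary;
  std::size_t pos = body.find(delim);
  while (pos != std::string::npos) {
    std::size_t headers_begin = pos + delim.size();
    // "--boundary--" closes the body: no more parts.
    if (body.compare(headers_begin, 2, "--") == 0) return false;
    // Part headers end at the first empty line.
    std::size_t headers_end = body.find("\r\n\r\n", headers_begin);
    if (headers_end == std::string::npos) return false;
    std::string headers =
        body.substr(headers_begin, headers_end - headers_begin);
    std::size_t data_begin = headers_end + 4;
    // Part content ends just before CRLF followed by the next delimiter.
    std::size_t data_end = body.find("\r\n" + delim, data_begin);
    if (data_end == std::string::npos) return false;
    std::size_t fn = headers.find("filename=\"");
    if (fn != std::string::npos) {
      std::size_t name_begin = fn + 10;  // length of `filename="`
      std::size_t name_end = headers.find('"', name_begin);
      if (name_end == std::string::npos) return false;
      out.file_name = headers.substr(name_begin, name_end - name_begin);
      out.data = body.substr(data_begin, data_end - data_begin);
      return true;
    }
    pos = data_end + 2;  // skip the CRLF and land on the next delimiter
  }
  return false;
}

// Mirrors parse_boundary_content(): a name containing ".wav" is decoded as
// "pcm"; otherwise the text after the last '.' is used (empty if none).
std::string wav_format_for(const std::string &file_name) {
  if (file_name.find(".wav") != std::string::npos) return "pcm";
  std::size_t dot = file_name.rfind('.');
  return dot == std::string::npos ? "" : file_name.substr(dot + 1);
}
```

Buffering the whole body is simpler but holds the entire upload in memory twice; the streaming parser above avoids the extra copy by appending directly to `data_msg->samples`.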
runtime/http/bin/funasr-http-main.cpp
New file
@@ -0,0 +1,523 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
#include "funasr-http-main.hpp"
#ifdef _WIN32
#include "win_func.h"
#include <windows.h>
#else
#include <unistd.h>
#endif
#include <fstream>
#include "util.h"
// hotwords
std::unordered_map<std::string, int> hws_map_;
int fst_inc_wts_ = 20;
float global_beam_, lattice_beam_, am_scale_;
using namespace std;
void GetValue(TCLAP::ValueArg<std::string> &value_arg, string key,
              std::map<std::string, std::string> &model_path) {
  model_path.insert({key, value_arg.getValue()});
  LOG(INFO) << key << " : " << value_arg.getValue();
}
FUNASR_HANDLE initAsr(std::map<std::string, std::string> &model_path,
                      int thread_num) {
  try {
    // init model with api
    FUNASR_HANDLE asr_handle = FunOfflineInit(model_path, thread_num);
    LOG(INFO) << "model successfully initialized";
    LOG(INFO) << "initAsr run check_and_clean_connection";
    // std::thread
    // clean_thread(&ModelDecoderSrv::check_and_clean_connection,this);
    // clean_thread.detach();
    LOG(INFO) << "initAsr run check_and_clean_connection finished";
    return asr_handle;
  } catch (const std::exception &e) {
    LOG(ERROR) << e.what();
  }
  // every path must return a value; nullptr signals failure to the caller
  return nullptr;
}
int main(int argc, char *argv[]) {
#ifdef _WIN32
  SetConsoleOutputCP(65001);
#endif
  try {
    google::InitGoogleLogging(argv[0]);
    FLAGS_logtostderr = true;
    std::string offline_version = "";
#ifdef _WIN32
    offline_version = "0.1.0";
#endif
    TCLAP::CmdLine cmd("funasr-wss-server", ' ', offline_version);
    TCLAP::ValueArg<std::string> download_model_dir(
        "", "download-model-dir",
        "Download model from Modelscope to download_model_dir", false,
        "/workspace/models", "string");
    TCLAP::ValueArg<std::string> model_dir(
        "", OFFLINE_MODEL_DIR,
        "default: "
        "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx, "
        "the asr model path, which "
        "contains model_quant.onnx, config.yaml, am.mvn",
        false,
        "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx",
        "string");
    TCLAP::ValueArg<std::string> model_revision("", "offline-model-revision",
                                                "ASR offline model revision",
                                                false, "v2.0.4", "string");
    TCLAP::ValueArg<std::string> quantize(
        "", QUANTIZE,
        "true (Default), load the model of model_quant.onnx in model_dir. If "
        "set "
        "false, load the model of model.onnx in model_dir",
        false, "true", "string");
    TCLAP::ValueArg<std::string> vad_dir(
        "", VAD_DIR,
        "default: damo/speech_fsmn_vad_zh-cn-16k-common-onnx, the vad model "
        "path, which contains "
        "model_quant.onnx, vad.yaml, vad.mvn",
        false, "damo/speech_fsmn_vad_zh-cn-16k-common-onnx", "string");
    TCLAP::ValueArg<std::string> vad_revision(
        "", "vad-revision", "VAD model revision", false, "v2.0.4", "string");
    TCLAP::ValueArg<std::string> vad_quant(
        "", VAD_QUANT,
        "true (Default), load the model of model_quant.onnx in vad_dir. If set "
        "false, load the model of model.onnx in vad_dir",
        false, "true", "string");
    TCLAP::ValueArg<std::string> punc_dir(
        "", PUNC_DIR,
        "default: "
        "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx, "
        "the punc model path, which contains "
        "model_quant.onnx, punc.yaml",
        false,
        "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx",
        "string");
    TCLAP::ValueArg<std::string> punc_revision(
        "", "punc-revision", "PUNC model revision", false, "v2.0.4", "string");
    TCLAP::ValueArg<std::string> punc_quant(
        "", PUNC_QUANT,
        "true (Default), load the model of model_quant.onnx in punc_dir. If "
        "set "
        "false, load the model of model.onnx in punc_dir",
        false, "true", "string");
    TCLAP::ValueArg<std::string> itn_dir(
        "", ITN_DIR,
        "default: thuduj12/fst_itn_zh, the itn model path, which contains "
        "zh_itn_tagger.fst, zh_itn_verbalizer.fst",
        false, "", "string");
    TCLAP::ValueArg<std::string> itn_revision(
        "", "itn-revision", "ITN model revision", false, "v1.0.1", "string");
    TCLAP::ValueArg<std::string> listen_ip("", "listen-ip", "listen ip", false,
                                           "0.0.0.0", "string");
    TCLAP::ValueArg<int> port("", "port", "port", false, 80, "int");
    TCLAP::ValueArg<int> io_thread_num("", "io-thread-num", "io thread num",
                                       false, 8, "int");
    TCLAP::ValueArg<int> decoder_thread_num(
        "", "decoder-thread-num", "decoder thread num", false, 32, "int");
    TCLAP::ValueArg<int> model_thread_num("", "model-thread-num",
                                          "model thread num", false, 1, "int");
    TCLAP::ValueArg<std::string> certfile(
        "", "certfile",
        "default: ../../../ssl_key/server.crt, path of certificate for WSS "
        "connection. If it is empty, it will be in WS mode.",
        false, "../../../ssl_key/server.crt", "string");
    TCLAP::ValueArg<std::string> keyfile(
        "", "keyfile",
        "default: ../../../ssl_key/server.key, path of keyfile for WSS "
        "connection",
        false, "../../../ssl_key/server.key", "string");
    TCLAP::ValueArg<float> global_beam("", GLOB_BEAM,
                                       "the decoding beam for beam searching ",
                                       false, 3.0, "float");
    TCLAP::ValueArg<float> lattice_beam(
        "", LAT_BEAM, "the lattice generation beam for beam searching ", false,
        3.0, "float");
    TCLAP::ValueArg<float> am_scale("", AM_SCALE,
                                    "the acoustic scale for beam searching ",
                                    false, 10.0, "float");
    TCLAP::ValueArg<std::string> lm_dir(
        "", LM_DIR,
        "the LM model path, which contains compiled models: TLG.fst, "
        "config.yaml ",
        false, "", "string");
    TCLAP::ValueArg<std::string> lm_revision(
        "", "lm-revision", "LM model revision", false, "v1.0.2", "string");
    TCLAP::ValueArg<std::string> hotword(
        "", HOTWORD,
        "the hotword file, one hotword per line, format: Hotword Weight (e.g. "
        "阿里巴巴 20)",
        false, "/workspace/resources/hotwords.txt", "string");
    TCLAP::ValueArg<std::int32_t> fst_inc_wts(
        "", FST_INC_WTS, "the fst hotwords incremental bias", false, 20,
        "int32_t");
    // add file
    cmd.add(hotword);
    cmd.add(fst_inc_wts);
    cmd.add(global_beam);
    cmd.add(lattice_beam);
    cmd.add(am_scale);
    cmd.add(certfile);
    cmd.add(keyfile);
    cmd.add(download_model_dir);
    cmd.add(model_dir);
    cmd.add(model_revision);
    cmd.add(quantize);
    cmd.add(vad_dir);
    cmd.add(vad_revision);
    cmd.add(vad_quant);
    cmd.add(punc_dir);
    cmd.add(punc_revision);
    cmd.add(punc_quant);
    cmd.add(itn_dir);
    cmd.add(itn_revision);
    cmd.add(lm_dir);
    cmd.add(lm_revision);
    cmd.add(listen_ip);
    cmd.add(port);
    cmd.add(io_thread_num);
    cmd.add(decoder_thread_num);
    cmd.add(model_thread_num);
    cmd.parse(argc, argv);
    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(vad_dir, VAD_DIR, model_path);
    GetValue(vad_quant, VAD_QUANT, model_path);
    GetValue(punc_dir, PUNC_DIR, model_path);
    GetValue(punc_quant, PUNC_QUANT, model_path);
    GetValue(itn_dir, ITN_DIR, model_path);
    GetValue(lm_dir, LM_DIR, model_path);
    GetValue(hotword, HOTWORD, model_path);
    GetValue(model_revision, "model-revision", model_path);
    GetValue(vad_revision, "vad-revision", model_path);
    GetValue(punc_revision, "punc-revision", model_path);
    GetValue(itn_revision, "itn-revision", model_path);
    GetValue(lm_revision, "lm-revision", model_path);
    global_beam_ = global_beam.getValue();
    lattice_beam_ = lattice_beam.getValue();
    am_scale_ = am_scale.getValue();
    // Download model from ModelScope
    try {
      std::string s_download_model_dir = download_model_dir.getValue();
      std::string s_vad_path = model_path[VAD_DIR];
      std::string s_vad_quant = model_path[VAD_QUANT];
      std::string s_asr_path = model_path[MODEL_DIR];
      std::string s_asr_quant = model_path[QUANTIZE];
      std::string s_punc_path = model_path[PUNC_DIR];
      std::string s_punc_quant = model_path[PUNC_QUANT];
      std::string s_itn_path = model_path[ITN_DIR];
      std::string s_lm_path = model_path[LM_DIR];
      std::string python_cmd =
          "python -m funasr.download.runtime_sdk_download_tool --type onnx "
          "--quantize True ";
      if (vad_dir.isSet() && !s_vad_path.empty()) {
        std::string python_cmd_vad;
        std::string down_vad_path;
        std::string down_vad_model;
        if (access(s_vad_path.c_str(), F_OK) == 0) {
          // local
          python_cmd_vad = python_cmd + " --model-name " + s_vad_path +
                           " --export-dir ./ " + " --model_revision " +
                           model_path["vad-revision"];
          down_vad_path = s_vad_path;
        } else {
          // modelscope
          LOG(INFO) << "Download model: " << s_vad_path << " from modelscope: ";
          python_cmd_vad = python_cmd + " --model-name " + s_vad_path +
                           " --export-dir " + s_download_model_dir +
                           " --model_revision " + model_path["vad-revision"];
          down_vad_path = s_download_model_dir + "/" + s_vad_path;
        }
        int ret = system(python_cmd_vad.c_str());
        if (ret != 0) {
          LOG(INFO) << "Failed to download model from modelscope. If you set "
                       "local vad model path, you can ignore the errors.";
        }
        down_vad_model = down_vad_path + "/model_quant.onnx";
        if (s_vad_quant == "false" || s_vad_quant == "False" ||
            s_vad_quant == "FALSE") {
          down_vad_model = down_vad_path + "/model.onnx";
        }
        if (access(down_vad_model.c_str(), F_OK) != 0) {
          LOG(ERROR) << down_vad_model << " does not exist.";
          exit(-1);
        } else {
          model_path[VAD_DIR] = down_vad_path;
          LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
        }
      } else {
        LOG(INFO) << "VAD model is not set, use default.";
      }
      if (model_dir.isSet() && !s_asr_path.empty()) {
        std::string python_cmd_asr;
        std::string down_asr_path;
        std::string down_asr_model;
        // modify model-revision by model name
        size_t found = s_asr_path.find(
            "speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-"
            "vocab8404");
        if (found != std::string::npos) {
          model_path["model-revision"] = "v1.2.4";
        }
        found = s_asr_path.find(
            "speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-"
            "vocab8404");
        if (found != std::string::npos) {
          model_path["model-revision"] = "v1.0.5";
        }
        found = s_asr_path.find(
            "speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
        if (found != std::string::npos) {
          model_path["model-revision"] = "v1.0.0";
          s_itn_path = "";
          s_lm_path = "";
        }
        if (access(s_asr_path.c_str(), F_OK) == 0) {
          // local
          python_cmd_asr = python_cmd + " --model-name " + s_asr_path +
                           " --export-dir ./ " + " --model_revision " +
                           model_path["model-revision"];
          down_asr_path = s_asr_path;
        } else {
          // modelscope
          LOG(INFO) << "Download model: " << s_asr_path << " from modelscope: ";
          python_cmd_asr = python_cmd + " --model-name " + s_asr_path +
                           " --export-dir " + s_download_model_dir +
                           " --model_revision " + model_path["model-revision"];
          down_asr_path = s_download_model_dir + "/" + s_asr_path;
        }
        int ret = system(python_cmd_asr.c_str());
        if (ret != 0) {
          LOG(INFO) << "Failed to download model from modelscope. If you set "
                       "local asr model path, you can ignore the errors.";
        }
        down_asr_model = down_asr_path + "/model_quant.onnx";
        if (s_asr_quant == "false" || s_asr_quant == "False" ||
            s_asr_quant == "FALSE") {
          down_asr_model = down_asr_path + "/model.onnx";
        }
        if (access(down_asr_model.c_str(), F_OK) != 0) {
          LOG(ERROR) << down_asr_model << " does not exist.";
          exit(-1);
        } else {
          model_path[MODEL_DIR] = down_asr_path;
          LOG(INFO) << "Set " << MODEL_DIR << " : " << model_path[MODEL_DIR];
        }
      } else {
        LOG(INFO) << "ASR model is not set, use default.";
      }
      if (!s_itn_path.empty()) {
        std::string python_cmd_itn;
        std::string down_itn_path;
        std::string down_itn_model;
        if (access(s_itn_path.c_str(), F_OK) == 0) {
          // local
          python_cmd_itn = python_cmd + " --model-name " + s_itn_path +
                           " --export-dir ./ " + " --model_revision " +
                           model_path["itn-revision"] + " --export False ";
          down_itn_path = s_itn_path;
        } else {
          // modelscope
          LOG(INFO) << "Download model: " << s_itn_path
                    << " from modelscope : ";
          python_cmd_itn = python_cmd + " --model-name " + s_itn_path +
                           " --export-dir " + s_download_model_dir +
                           " --model_revision " + model_path["itn-revision"] +
                           " --export False ";
          down_itn_path = s_download_model_dir + "/" + s_itn_path;
        }
        int ret = system(python_cmd_itn.c_str());
        if (ret != 0) {
          LOG(INFO) << "Failed to download model from modelscope. If you set "
                       "local itn model path, you can ignore the errors.";
        }
        down_itn_model = down_itn_path + "/zh_itn_tagger.fst";
        if (access(down_itn_model.c_str(), F_OK) != 0) {
          LOG(ERROR) << down_itn_model << " does not exist.";
          exit(-1);
        } else {
          model_path[ITN_DIR] = down_itn_path;
          LOG(INFO) << "Set " << ITN_DIR << " : " << model_path[ITN_DIR];
        }
      } else {
        LOG(INFO) << "ITN model is not set, not executed.";
      }
      if (!s_lm_path.empty() && s_lm_path != "NONE" && s_lm_path != "none") {
        std::string python_cmd_lm;
        std::string down_lm_path;
        std::string down_lm_model;
        if (access(s_lm_path.c_str(), F_OK) == 0) {
          // local
          python_cmd_lm = python_cmd + " --model-name " + s_lm_path +
                          " --export-dir ./ " + " --model_revision " +
                          model_path["lm-revision"] + " --export False ";
          down_lm_path = s_lm_path;
        } else {
          // modelscope
          LOG(INFO) << "Download model: " << s_lm_path << " from modelscope : ";
          python_cmd_lm = python_cmd + " --model-name " + s_lm_path +
                          " --export-dir " + s_download_model_dir +
                          " --model_revision " + model_path["lm-revision"] +
                          " --export False ";
          down_lm_path = s_download_model_dir + "/" + s_lm_path;
        }
        int ret = system(python_cmd_lm.c_str());
        if (ret != 0) {
          LOG(INFO) << "Failed to download model from modelscope. If you set "
                       "local lm model path, you can ignore the errors.";
        }
        down_lm_model = down_lm_path + "/TLG.fst";
        if (access(down_lm_model.c_str(), F_OK) != 0) {
          LOG(ERROR) << down_lm_model << " does not exist.";
          exit(-1);
        } else {
          model_path[LM_DIR] = down_lm_path;
          LOG(INFO) << "Set " << LM_DIR << " : " << model_path[LM_DIR];
        }
      } else {
        LOG(INFO) << "LM model is not set, not executed.";
        model_path[LM_DIR] = "";
      }
      if (punc_dir.isSet() && !s_punc_path.empty()) {
        std::string python_cmd_punc;
        std::string down_punc_path;
        std::string down_punc_model;
        if (access(s_punc_path.c_str(), F_OK) == 0) {
          // local
          python_cmd_punc = python_cmd + " --model-name " + s_punc_path +
                            " --export-dir ./ " + " --model_revision " +
                            model_path["punc-revision"];
          down_punc_path = s_punc_path;
        } else {
          // modelscope
          LOG(INFO) << "Download model: " << s_punc_path
                    << " from modelscope: ";
          python_cmd_punc = python_cmd + " --model-name " + s_punc_path +
                            " --export-dir " + s_download_model_dir +
                            " --model_revision " + model_path["punc-revision"];
          down_punc_path = s_download_model_dir + "/" + s_punc_path;
        }
        int ret = system(python_cmd_punc.c_str());
        if (ret != 0) {
          LOG(INFO) << "Failed to download model from modelscope. If you set "
                       "local punc model path, you can ignore the errors.";
        }
        down_punc_model = down_punc_path + "/model_quant.onnx";
        if (s_punc_quant == "false" || s_punc_quant == "False" ||
            s_punc_quant == "FALSE") {
          down_punc_model = down_punc_path + "/model.onnx";
        }
        if (access(down_punc_model.c_str(), F_OK) != 0) {
          LOG(ERROR) << down_punc_model << " does not exist.";
          exit(-1);
        } else {
          model_path[PUNC_DIR] = down_punc_path;
          LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
        }
      } else {
        LOG(INFO) << "PUNC model is not set, use default.";
      }
    } catch (std::exception const &e) {
      LOG(ERROR) << "Error: " << e.what();
    }
    std::string s_listen_ip = listen_ip.getValue();
    int s_port = port.getValue();
    int s_io_thread_num = io_thread_num.getValue();
    int s_decoder_thread_num = decoder_thread_num.getValue();
    int s_model_thread_num = model_thread_num.getValue();
    asio::io_context io_decoder;  // context for decoding
    std::vector<std::thread> decoder_threads;
    // hotword file
    std::string hotword_path;
    hotword_path = model_path.at(HOTWORD);
    fst_inc_wts_ = fst_inc_wts.getValue();
    LOG(INFO) << "hotword path: " << hotword_path;
    funasr::ExtractHws(hotword_path, hws_map_);
    auto conn_guard = asio::make_work_guard(
        io_decoder);  // make sure threads can wait in the queue
    // create threads pool
    for (int32_t i = 0; i < s_decoder_thread_num; ++i) {
      decoder_threads.emplace_back([&io_decoder]() { io_decoder.run(); });
    }
    // ModelDecoderSrv modelSrv(
    //     io_decoder);  // websocket server for asr engine
    // modelSrv.initAsr(model_path, s_model_thread_num);  // init asr model
    // FUNASR_HANDLE asr_handle= initAsr();
    LOG(INFO) << "decoder-thread-num: " << s_decoder_thread_num;
    LOG(INFO) << "io-thread-num: " << s_io_thread_num;
    LOG(INFO) << "model-thread-num: " << s_model_thread_num;
    http::server2::server s(s_listen_ip, std::to_string(s_port), "./",
                            s_io_thread_num, io_decoder, model_path,
                            s_model_thread_num);
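    LOG(INFO) << "http server listening on " << s_listen_ip << ":" << s_port;
    s.run();  // blocks until the server is stopped
    // Release the work guard so io_decoder.run() can return once its queue
    // drains; otherwise the joins below would block forever.
    conn_guard.reset();
    // wait for threads
    for (auto &t : decoder_threads) {
      t.join();
    }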
  } catch (std::exception const &e) {
    LOG(ERROR) << "Error: " << e.what();
  }
  return 0;
}
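`main()` repeats the same steps for the VAD, ASR and PUNC models: choose either the local path or `download_model_dir/<name>`, run `funasr.download.runtime_sdk_download_tool`, then check for `model_quant.onnx`, or `model.onnx` when the quantize flag is `false`/`False`/`FALSE`. A possible refactoring is sketched below with hypothetical helpers `resolve_model` and `download_command`; the `is_local` flag stands in for the `access(path, F_OK)` check, and the command string matches the original apart from collapsed double spaces.

```cpp
#include <cassert>
#include <string>

// Hypothetical result of resolving one model argument.
struct ModelLocation {
  std::string dir;         // value stored back into model_path[...]
  std::string onnx_file;   // file whose existence main() checks
  std::string export_dir;  // value passed to --export-dir
};

// Hypothetical helper: is_local replaces the access(path, F_OK) == 0 test.
ModelLocation resolve_model(const std::string &name_or_path, bool is_local,
                            const std::string &download_dir,
                            const std::string &quant_flag) {
  ModelLocation loc;
  loc.dir = is_local ? name_or_path : download_dir + "/" + name_or_path;
  loc.export_dir = is_local ? "./" : download_dir;
  // Same three spellings the original accepts for "not quantized".
  bool no_quant = quant_flag == "false" || quant_flag == "False" ||
                  quant_flag == "FALSE";
  loc.onnx_file = loc.dir + (no_quant ? "/model.onnx" : "/model_quant.onnx");
  return loc;
}

// Hypothetical helper: the command line main() hands to system(); module
// name and flags are copied from the code above.
std::string download_command(const std::string &name_or_path,
                             const ModelLocation &loc,
                             const std::string &revision) {
  return "python -m funasr.download.runtime_sdk_download_tool --type onnx "
         "--quantize True --model-name " + name_or_path +
         " --export-dir " + loc.export_dir + " --model_revision " + revision;
}
```

The ITN and LM branches differ (they append `--export False` and check `zh_itn_tagger.fst` or `TLG.fst`), so they would need an extra parameter for the expected file name.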
runtime/http/bin/funasr-http-main.hpp
New file
@@ -0,0 +1,20 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
#ifndef HTTP_SERVER2_MAIN_HPP
#define HTTP_SERVER2_MAIN_HPP
#include "model-decoder.h"
#include "server.hpp"
namespace http {
namespace server2 {
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_MAIN_HPP
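The `--hotword` option of funasr-http-main.cpp expects one hotword per line in the form `Hotword Weight` (for example `阿里巴巴 20`), and `funasr::ExtractHws` loads that file into `hws_map_`. The parser below only illustrates one way such a file could be read; `parse_hotwords` is hypothetical, not `funasr::ExtractHws`, and it assumes that the weight is the last whitespace-separated token, so a phrase may itself contain spaces.

```cpp
#include <cassert>
#include <cctype>
#include <istream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>

namespace {
// Removes trailing whitespace, including the '\r' left by CRLF files.
void rtrim(std::string &s) {
  while (!s.empty() && std::isspace(static_cast<unsigned char>(s.back())))
    s.pop_back();
}
}  // namespace

// Hypothetical reader for "Hotword Weight" lines (NOT funasr::ExtractHws).
// Malformed lines are skipped; for a duplicated phrase the last line wins.
std::unordered_map<std::string, int> parse_hotwords(std::istream &in) {
  std::unordered_map<std::string, int> hws;
  std::string line;
  while (std::getline(in, line)) {
    rtrim(line);
    std::size_t first = line.find_first_not_of(" \t");
    if (first == std::string::npos) continue;  // blank line
    line.erase(0, first);
    std::size_t split = line.find_last_of(" \t");
    if (split == std::string::npos) continue;  // no weight column
    std::string weight = line.substr(split + 1);
    std::string phrase = line.substr(0, split);
    rtrim(phrase);
    try {
      std::size_t used = 0;
      int w = std::stoi(weight, &used);
      // Reject weights with trailing garbage such as "20x".
      if (used == weight.size() && !phrase.empty()) hws[phrase] = w;
    } catch (const std::invalid_argument &) {
      // weight is not a number: skip the line
    } catch (const std::out_of_range &) {
      // weight does not fit in an int: skip the line
    }
  }
  return hws;
}
```

For example, reading from `std::istringstream("阿里巴巴 20\n")` yields a map with the single entry `阿里巴巴` mapped to `20`.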
runtime/http/bin/header.hpp
New file
@@ -0,0 +1,27 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
//
// header.hpp
// Some code adapted from http://www.boost.org/
#ifndef HTTP_SERVER2_HEADER_HPP
#define HTTP_SERVER2_HEADER_HPP
#include <string>
namespace http {
namespace server2 {
struct header
{
  std::string name;
  std::string value;
};
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_HEADER_HPP
runtime/http/bin/io_context_pool.cpp
New file
@@ -0,0 +1,66 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
//
// io_context_pool.cpp
// ~~~~~~~~~~~~~~~~~~~
// Some code adapted from http://www.boost.org/
#include "io_context_pool.hpp"
#include <stdexcept>
#include <thread>
namespace http {
namespace server2 {
io_context_pool::io_context_pool(std::size_t pool_size)
  : next_io_context_(0)
{
  if (pool_size == 0)
    throw std::runtime_error("io_context_pool size is 0");
  // Give all the io_contexts work to do so that their run() functions will not
  // exit until they are explicitly stopped.
  for (std::size_t i = 0; i < pool_size; ++i)
  {
    io_context_ptr io_context(new asio::io_context);
    io_contexts_.push_back(io_context);
    work_.push_back(asio::make_work_guard(*io_context));
  }
}
void io_context_pool::run()
{
  // Create a pool of threads to run all of the io_contexts.
  std::vector<std::thread> threads;
  for (std::size_t i = 0; i < io_contexts_.size(); ++i)
    threads.emplace_back([this, i]{ io_contexts_[i]->run(); });
  // Wait for all threads in the pool to exit.
  for (std::size_t i = 0; i < threads.size(); ++i)
    threads[i].join();
}
void io_context_pool::stop()
{
  // Explicitly stop all io_contexts.
  for (std::size_t i = 0; i < io_contexts_.size(); ++i)
    io_contexts_[i]->stop();
}
asio::io_context& io_context_pool::get_io_context()
{
  // Use a round-robin scheme to choose the next io_context to use.
  asio::io_context& io_context = *io_contexts_[next_io_context_];
  ++next_io_context_;
  if (next_io_context_ == io_contexts_.size())
    next_io_context_ = 0;
  return io_context;
}
} // namespace server2
} // namespace http
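`io_context_pool::get_io_context()` hands out io_contexts in round-robin order, which spreads new connections across the I/O threads. Below is a minimal sketch of the same index arithmetic without asio. The `RoundRobin` class is a hypothetical name used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>

// Hypothetical, asio-free sketch of the selection done by
// io_context_pool::get_io_context(): return the current slot, then advance,
// wrapping back to 0 after the last one.
class RoundRobin {
 public:
  explicit RoundRobin(std::size_t size) : size_(size) {
    if (size_ == 0) throw std::runtime_error("pool size is 0");
  }
  std::size_t next() {
    const std::size_t current = next_;
    ++next_;
    if (next_ == size_) next_ = 0;  // wrap around
    return current;
  }

 private:
  std::size_t size_;
  std::size_t next_ = 0;
};
```

Like the original, `next()` is not synchronized. That appears safe here because the pool is only consulted from the server constructor and from the acceptor's completion handler, and that handler runs on a single io_context thread.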
runtime/http/bin/io_context_pool.hpp
New file
@@ -0,0 +1,59 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
//
// io_context_pool.hpp
// ~~~~~~~~~~~~~~~~~~~
// Some code adapted from http://www.boost.org/
#ifndef HTTP_SERVER2_IO_SERVICE_POOL_HPP
#define HTTP_SERVER2_IO_SERVICE_POOL_HPP
#include <asio.hpp>
#include <list>
#include <memory>
#include <vector>
namespace http {
namespace server2 {
/// A pool of io_context objects.
class io_context_pool
{
public:
  /// Construct the io_context pool.
  explicit io_context_pool(std::size_t pool_size);
  /// Run all io_context objects in the pool.
  void run();
  /// Stop all io_context objects in the pool.
  void stop();
  /// Get an io_context to use.
  asio::io_context& get_io_context();
private:
  io_context_pool(const io_context_pool&) = delete;
  io_context_pool& operator=(const io_context_pool&) = delete;
  typedef std::shared_ptr<::asio::io_context> io_context_ptr;
  typedef asio::executor_work_guard<
    asio::io_context::executor_type> io_context_work;
  /// The pool of io_contexts.
  std::vector<io_context_ptr> io_contexts_;
  /// The work that keeps the io_contexts running.
  std::list<io_context_work> work_;
  /// The next io_context to use for a connection.
  std::size_t next_io_context_;
};
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_IO_SERVICE_POOL_HPP
runtime/http/bin/model-decoder.cpp
New file
@@ -0,0 +1,119 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
// funasr asr engine
#include "model-decoder.h"
#include <thread>
#include <utility>
#include <vector>
extern std::unordered_map<std::string, int> hws_map_;
extern int fst_inc_wts_;
extern float global_beam_, lattice_beam_, am_scale_;
// feed msg to asr engine for decoder
void ModelDecoder::do_decoder(std::shared_ptr<FUNASR_MESSAGE> session_msg) {
  try {
    //   std::this_thread::sleep_for(std::chrono::milliseconds(1000*10));
    if (session_msg->status == 1) return;
    //std::cout << "in do_decoder" << std::endl;
    std::shared_ptr<std::vector<char>> buffer = session_msg->samples;
    int num_samples = buffer->size();  // buffer size in bytes
    std::string wav_name = session_msg->msg["wav_name"];
    bool itn = session_msg->msg["itn"];
    int audio_fs = session_msg->msg["audio_fs"];
    std::string wav_format = session_msg->msg["wav_format"];
    if (num_samples > 0 && session_msg->hotwords_embedding->size() > 0) {
      std::string asr_result = "";
      std::string stamp_res = "";
      std::string stamp_sents = "";
      try {
        std::vector<std::vector<float>> hotwords_embedding_(
            *(session_msg->hotwords_embedding));
        FUNASR_RESULT Result = FunOfflineInferBuffer(
            asr_handle, buffer->data(), buffer->size(), RASR_NONE, nullptr,
            std::move(hotwords_embedding_), audio_fs, wav_format, itn,
            session_msg->decoder_handle);
        if (Result != nullptr) {
          asr_result = FunASRGetResult(Result, 0);  // get decode result
          stamp_res = FunASRGetStamp(Result);
          stamp_sents = FunASRGetStampSents(Result);
          FunASRFreeResult(Result);
        } else {
          std::this_thread::sleep_for(std::chrono::milliseconds(20));
        }
      } catch (std::exception const &e) {
        std::cerr << "Error in decoder: " << e.what() << std::endl;
      }
      nlohmann::json jsonresult;        // result json
      jsonresult["text"] = asr_result;  // put result in 'text'
      jsonresult["mode"] = "offline";
      jsonresult["is_final"] = false;
      if (stamp_res != "") {
        jsonresult["timestamp"] = stamp_res;
      }
      if (stamp_sents != "") {
        try {
          nlohmann::json json_stamp = nlohmann::json::parse(stamp_sents);
          jsonresult["stamp_sents"] = json_stamp;
        } catch (std::exception const &e) {
          std::cerr << "Error parsing stamp_sents: " << e.what() << std::endl;
          jsonresult["stamp_sents"] = "";
        }
      }
      jsonresult["wav_name"] = wav_name;
      std::cout << "buffer.size=" << buffer->size()
                << ",result json=" << jsonresult.dump() << std::endl;
      FunWfstDecoderUnloadHwsRes(session_msg->decoder_handle);
      FunASRWfstDecoderUninit(session_msg->decoder_handle);
      session_msg->status = 1;
      session_msg->msg["asr_result"] = jsonresult;
      return;
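    } else {
      std::cout << "Sent empty msg" << std::endl;
      nlohmann::json jsonresult;  // result json
      jsonresult["text"] = "";    // put result in 'text'
      jsonresult["mode"] = "offline";
      jsonresult["is_final"] = false;
      jsonresult["wav_name"] = wav_name;
      // Hand the empty result back as well; otherwise this JSON is discarded.
      session_msg->msg["asr_result"] = jsonresult;
      session_msg->status = 1;
    }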
  } catch (std::exception const &e) {
    std::cerr << "Error: " << e.what() << std::endl;
  }
}
// init asr model
FUNASR_HANDLE ModelDecoder::initAsr(std::map<std::string, std::string> &model_path,
                                    int thread_num) {
  try {
    // init model with api
    asr_handle = FunOfflineInit(model_path, thread_num);
    LOG(INFO) << "model successfully initialized";
    return asr_handle;
  } catch (const std::exception &e) {
    LOG(ERROR) << e.what();
    return nullptr;
  }
}
runtime/http/bin/model-decoder.h
New file
@@ -0,0 +1,60 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
// funasr asr engine
#ifndef MODEL_DECODER_SERVER_H_
#define MODEL_DECODER_SERVER_H_
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#define ASIO_STANDALONE 1  // use standalone Asio, not Boost.Asio
#include <glog/logging.h>
#include <fstream>
#include <functional>
#include "asio.hpp"
#include "asr_sessions.h"
#include "com-define.h"
#include "funasrruntime.h"
#include "nlohmann/json.hpp"
#include "tclap/CmdLine.h"
#include "util/text-utils.h"
class ModelDecoder {
 public:
  ModelDecoder(asio::io_context &io_decoder,
               std::map<std::string, std::string> &model_path, int thread_num)
      : io_decoder_(io_decoder) {
    asr_handle = initAsr(model_path, thread_num);
  }
  void do_decoder(std::shared_ptr<FUNASR_MESSAGE> session_msg);
  FUNASR_HANDLE initAsr(std::map<std::string, std::string> &model_path, int thread_num);
  asio::io_context &io_decoder_;  // threads for asr decoder
  FUNASR_HANDLE get_asr_handle()
  {
    return asr_handle;
  }
 private:
  FUNASR_HANDLE asr_handle;  // asr engine handle
  bool isonline = false;  // online or offline engine; only offline is currently supported
};
#endif  // MODEL_DECODER_SERVER_H_
runtime/http/bin/reply.cpp
New file
@@ -0,0 +1,245 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
// reply.cpp
// ~~~~~~~~~
//
// Some code adapted from http://www.boost.org/
#include "reply.hpp"
#include <iostream>
#include <string>
namespace http {
namespace server2 {
namespace status_strings {
const std::string ok = "HTTP/1.0 200 OK\r\n";
const std::string created = "HTTP/1.0 201 Created\r\n";
const std::string accepted = "HTTP/1.0 202 Accepted\r\n";
const std::string no_content = "HTTP/1.0 204 No Content\r\n";
const std::string multiple_choices = "HTTP/1.0 300 Multiple Choices\r\n";
const std::string moved_permanently = "HTTP/1.0 301 Moved Permanently\r\n";
const std::string moved_temporarily = "HTTP/1.0 302 Moved Temporarily\r\n";
const std::string not_modified = "HTTP/1.0 304 Not Modified\r\n";
const std::string bad_request = "HTTP/1.0 400 Bad Request\r\n";
const std::string unauthorized = "HTTP/1.0 401 Unauthorized\r\n";
const std::string forbidden = "HTTP/1.0 403 Forbidden\r\n";
const std::string not_found = "HTTP/1.0 404 Not Found\r\n";
const std::string internal_server_error =
    "HTTP/1.0 500 Internal Server Error\r\n";
const std::string not_implemented = "HTTP/1.0 501 Not Implemented\r\n";
const std::string bad_gateway = "HTTP/1.0 502 Bad Gateway\r\n";
const std::string service_unavailable = "HTTP/1.0 503 Service Unavailable\r\n";
asio::const_buffer to_buffer(reply::status_type status) {
  switch (status) {
    case reply::ok:
      return asio::buffer(ok);
    case reply::created:
      return asio::buffer(created);
    case reply::accepted:
      return asio::buffer(accepted);
    case reply::no_content:
      return asio::buffer(no_content);
    case reply::multiple_choices:
      return asio::buffer(multiple_choices);
    case reply::moved_permanently:
      return asio::buffer(moved_permanently);
    case reply::moved_temporarily:
      return asio::buffer(moved_temporarily);
    case reply::not_modified:
      return asio::buffer(not_modified);
    case reply::bad_request:
      return asio::buffer(bad_request);
    case reply::unauthorized:
      return asio::buffer(unauthorized);
    case reply::forbidden:
      return asio::buffer(forbidden);
    case reply::not_found:
      return asio::buffer(not_found);
    case reply::internal_server_error:
      return asio::buffer(internal_server_error);
    case reply::not_implemented:
      return asio::buffer(not_implemented);
    case reply::bad_gateway:
      return asio::buffer(bad_gateway);
    case reply::service_unavailable:
      return asio::buffer(service_unavailable);
    default:
      return asio::buffer(internal_server_error);
  }
}
}  // namespace status_strings
namespace misc_strings {
const char name_value_separator[] = {':', ' '};
const char crlf[] = {'\r', '\n'};
}  // namespace misc_strings
std::vector<::asio::const_buffer> reply::to_buffers() {
  std::vector<::asio::const_buffer> buffers;
  buffers.push_back(status_strings::to_buffer(status));
  for (std::size_t i = 0; i < headers.size(); ++i) {
    header &h = headers[i];
    buffers.push_back(asio::buffer(h.name));
    buffers.push_back(asio::buffer(misc_strings::name_value_separator));
    buffers.push_back(asio::buffer(h.value));
    buffers.push_back(asio::buffer(misc_strings::crlf));
  }
  buffers.push_back(asio::buffer(misc_strings::crlf));
  buffers.push_back(asio::buffer(content));
  return buffers;
}
namespace stock_replies {
const char ok[] = "";
const char created[] =
    "<html>"
    "<head><title>Created</title></head>"
    "<body><h1>201 Created</h1></body>"
    "</html>";
const char accepted[] =
    "<html>"
    "<head><title>Accepted</title></head>"
    "<body><h1>202 Accepted</h1></body>"
    "</html>";
const char no_content[] =
    "<html>"
    "<head><title>No Content</title></head>"
    "<body><h1>204 No Content</h1></body>"
    "</html>";
const char multiple_choices[] =
    "<html>"
    "<head><title>Multiple Choices</title></head>"
    "<body><h1>300 Multiple Choices</h1></body>"
    "</html>";
const char moved_permanently[] =
    "<html>"
    "<head><title>Moved Permanently</title></head>"
    "<body><h1>301 Moved Permanently</h1></body>"
    "</html>";
const char moved_temporarily[] =
    "<html>"
    "<head><title>Moved Temporarily</title></head>"
    "<body><h1>302 Moved Temporarily</h1></body>"
    "</html>";
const char not_modified[] =
    "<html>"
    "<head><title>Not Modified</title></head>"
    "<body><h1>304 Not Modified</h1></body>"
    "</html>";
const char bad_request[] =
    "<html>"
    "<head><title>Bad Request</title></head>"
    "<body><h1>400 Bad Request</h1></body>"
    "</html>";
const char unauthorized[] =
    "<html>"
    "<head><title>Unauthorized</title></head>"
    "<body><h1>401 Unauthorized</h1></body>"
    "</html>";
const char forbidden[] =
    "<html>"
    "<head><title>Forbidden</title></head>"
    "<body><h1>403 Forbidden</h1></body>"
    "</html>";
const char not_found[] =
    "<html>"
    "<head><title>Not Found</title></head>"
    "<body><h1>404 Not Found</h1></body>"
    "</html>";
const char internal_server_error[] =
    "<html>"
    "<head><title>Internal Server Error</title></head>"
    "<body><h1>500 Internal Server Error</h1></body>"
    "</html>";
const char not_implemented[] =
    "<html>"
    "<head><title>Not Implemented</title></head>"
    "<body><h1>501 Not Implemented</h1></body>"
    "</html>";
const char bad_gateway[] =
    "<html>"
    "<head><title>Bad Gateway</title></head>"
    "<body><h1>502 Bad Gateway</h1></body>"
    "</html>";
const char service_unavailable[] =
    "<html>"
    "<head><title>Service Unavailable</title></head>"
    "<body><h1>503 Service Unavailable</h1></body>"
    "</html>";
std::string to_string(reply::status_type status) {
  switch (status) {
    case reply::ok:
      return ok;
    case reply::created:
      return created;
    case reply::accepted:
      return accepted;
    case reply::no_content:
      return no_content;
    case reply::multiple_choices:
      return multiple_choices;
    case reply::moved_permanently:
      return moved_permanently;
    case reply::moved_temporarily:
      return moved_temporarily;
    case reply::not_modified:
      return not_modified;
    case reply::bad_request:
      return bad_request;
    case reply::unauthorized:
      return unauthorized;
    case reply::forbidden:
      return forbidden;
    case reply::not_found:
      return not_found;
    case reply::internal_server_error:
      return internal_server_error;
    case reply::not_implemented:
      return not_implemented;
    case reply::bad_gateway:
      return bad_gateway;
    case reply::service_unavailable:
      return service_unavailable;
    default:
      return internal_server_error;
  }
}
}  // namespace stock_replies
reply reply::stock_reply(std::string jsonresult) {
  reply rep;
  rep.status = reply::ok;
  rep.content = jsonresult + "\n";
  rep.headers.resize(2);
  rep.headers[0].name = "Content-Length";
  rep.headers[0].value = std::to_string(rep.content.size());
  rep.headers[1].name = "Content-Type";
  rep.headers[1].value = "text/html;charset=utf-8";
  return rep;
}
reply reply::stock_reply(reply::status_type status) {
  reply rep;
  rep.status = status;
  rep.content = stock_replies::to_string(status);
  rep.headers.resize(2);
  rep.headers[0].name = "Content-Length";
  rep.headers[0].value = std::to_string(rep.content.size());
  rep.headers[1].name = "Content-Type";
  rep.headers[1].value = "text/html";
  return rep;
}
}  // namespace server2
}  // namespace http
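`reply::to_buffers()` copies nothing. It gathers the status line, one `Name: value\r\n` pair per header, an empty line and the body into a list of asio buffers. The sketch below joins the same pieces into a single string, using only the standard library, so you can see the exact bytes `reply::stock_reply(std::string)` puts on the wire. `SerializeReply` and `JsonReply` are hypothetical helpers.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: concatenates what reply::to_buffers() hands to asio.
std::string SerializeReply(
    const std::string &status_line,
    const std::vector<std::pair<std::string, std::string>> &headers,
    const std::string &content) {
  std::string out = status_line;  // e.g. "HTTP/1.0 200 OK\r\n"
  for (const auto &h : headers) {
    out += h.first;
    out += ": ";  // name_value_separator
    out += h.second;
    out += "\r\n";  // crlf after each header
  }
  out += "\r\n";  // blank line ends the header block
  out += content;
  return out;
}

// Mirrors reply::stock_reply(std::string): body is the JSON plus '\n', and
// Content-Length is the byte size of that body.
std::string JsonReply(const std::string &json) {
  const std::string body = json + "\n";
  return SerializeReply("HTTP/1.0 200 OK\r\n",
                        {{"Content-Length", std::to_string(body.size())},
                         {"Content-Type", "text/html;charset=utf-8"}},
                        body);
}
```

Content-Length comes from `std::string::size()`. That counts bytes, so the header stays correct for UTF-8 (for example Chinese) transcripts.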
runtime/http/bin/reply.hpp
New file
@@ -0,0 +1,64 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
// reply.hpp
// ~~~~~~~~~
//
// Some code adapted from http://www.boost.org/
#ifndef HTTP_SERVER2_REPLY_HPP
#define HTTP_SERVER2_REPLY_HPP
#include <asio.hpp>
#include <string>
#include <vector>
#include "header.hpp"
namespace http {
namespace server2 {
/// A reply to be sent to a client.
struct reply {
  /// The status of the reply.
  enum status_type {
    ok = 200,
    created = 201,
    accepted = 202,
    no_content = 204,
    multiple_choices = 300,
    moved_permanently = 301,
    moved_temporarily = 302,
    not_modified = 304,
    bad_request = 400,
    unauthorized = 401,
    forbidden = 403,
    not_found = 404,
    internal_server_error = 500,
    not_implemented = 501,
    bad_gateway = 502,
    service_unavailable = 503
  } status;
  /// The headers to be included in the reply.
  std::vector<header> headers;
  /// The content to be sent in the reply.
  std::string content;
  /// Convert the reply into a vector of buffers. The buffers do not own the
  /// underlying memory blocks, therefore the reply object must remain valid and
  /// not be changed until the write operation has completed.
  std::vector<::asio::const_buffer> to_buffers();
  /// Get a stock reply.
  static reply stock_reply(status_type status);
  static reply stock_reply(std::string jsonresult);
};
}  // namespace server2
}  // namespace http
#endif  // HTTP_SERVER2_REPLY_HPP
runtime/http/bin/server.cpp
New file
@@ -0,0 +1,113 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
//
// server.cpp
// Some code adapted from http://www.boost.org/
#include "server.hpp"
#include <signal.h>
#include <fstream>
#include <iostream>
#include <utility>
#include "util.h"
namespace http {
namespace server2 {
server::server(const std::string &address, const std::string &port,
               const std::string &doc_root, std::size_t io_context_pool_size,
               asio::io_context &decoder_context,
               std::map<std::string, std::string> &model_path, int thread_num)
    : io_context_pool_(io_context_pool_size),
      decoder_context(decoder_context),
      signals_(io_context_pool_.get_io_context()),
      acceptor_(io_context_pool_.get_io_context()) {
  // Register to handle the signals that indicate when the server should exit.
  // It is safe to register for the same signal multiple times in a program,
  // provided all registration for the specified signal is made through Asio.
  try {
    // init model with api
    model_decoder =
        std::make_shared<ModelDecoder>(decoder_context, model_path, thread_num);
    LOG(INFO) << "Trying to listen on port " << port;
    LOG(INFO) << "Not listening yet, please wait...";
    LOG(INFO) << "If it keeps waiting here, the port may already be in use; "
                 "change the port or kill the previous process.";
    atom_id = 0;
    signals_.add(SIGINT);
    signals_.add(SIGTERM);
#if defined(SIGQUIT)
    signals_.add(SIGQUIT);
#endif  // defined(SIGQUIT)
    do_await_stop();
    // Open the acceptor with the option to reuse the address (i.e.
    // SO_REUSEADDR).
    asio::ip::tcp::resolver resolver(acceptor_.get_executor());
    asio::ip::tcp::endpoint endpoint = *resolver.resolve(address, port).begin();
    acceptor_.open(endpoint.protocol());
    acceptor_.set_option(asio::ip::tcp::acceptor::reuse_address(true));
    acceptor_.bind(endpoint);
    acceptor_.listen();
    do_accept();
    std::cout << "Test with curl, for example:" << std::endl;
    std::cout << "curl -F \"file=@example.wav\" 127.0.0.1:" << port
              << std::endl;
    std::cout << "HTTP POST supports offline mode only; for online mode, "
                 "use the websocket server."
              << std::endl;
    std::cout << "Listening on " << address << ":" << port
              << ", ready to accept data." << std::endl;
  } catch (const std::exception &e) {
    LOG(ERROR) << "error: " << e.what();
  }
}
void server::run() { io_context_pool_.run(); }
void server::do_accept() {
  acceptor_.async_accept(
      io_context_pool_.get_io_context(),
      [this](asio::error_code ec, asio::ip::tcp::socket socket) {
        // Check whether the server was stopped by a signal before this
        // completion handler had a chance to run.
        if (!acceptor_.is_open()) {
          return;
        }
        if (!ec) {
          // ++ on std::atomic is a single atomic read-modify-write, so its
          // result is a unique id without an extra mutex.
          const int conn_id = ++atom_id;
          std::make_shared<connection>(std::move(socket), decoder_context,
                                       conn_id, model_decoder)
              ->start();
        }
        do_accept();
      });
}
void server::do_await_stop() {
  signals_.async_wait([this](asio::error_code /*ec*/, int /*signo*/) {
    io_context_pool_.stop();
  });
}
}  // namespace server2
}  // namespace http
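In `server::do_accept`, each accepted connection gets an id from `atom_id`, which is a `std::atomic<int>`. The sketch below uses only the standard library to show why the value returned by `++` is enough. The pre-increment is one atomic read-modify-write, so concurrent callers never receive the same result. A separate `load()` after the increment is different: without a lock, it could observe another thread's later increment. `AssignIds` is a hypothetical name used for illustration.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <set>
#include <thread>
#include <vector>

// Hypothetical helper: `threads` workers each take `per_thread` ids from one
// shared atomic counter, similar to how do_accept numbers connections.
std::vector<int> AssignIds(int threads, int per_thread) {
  std::atomic<int> counter{0};
  std::vector<int> ids(static_cast<std::size_t>(threads) * per_thread);
  std::vector<std::thread> workers;
  for (int t = 0; t < threads; ++t) {
    workers.emplace_back([&ids, &counter, t, per_thread] {
      for (int i = 0; i < per_thread; ++i) {
        // ++ returns the incremented value atomically; no mutex needed.
        // Each worker writes only its own slice of `ids`, so no data race.
        ids[static_cast<std::size_t>(t * per_thread + i)] = ++counter;
      }
    });
  }
  for (auto &w : workers) w.join();
  return ids;
}
```

Ids produced this way start at 1, because `atom_id = 0` is followed by a pre-increment.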
runtime/http/bin/server.hpp
New file
@@ -0,0 +1,71 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
//
// server.hpp
// ~~~~~~~~~~
// Some code adapted from http://www.boost.org/
#ifndef HTTP_SERVER2_SERVER_HPP
#define HTTP_SERVER2_SERVER_HPP
#include <asio.hpp>
#include <atomic>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include "connection.hpp"
#include "funasrruntime.h"
#include "io_context_pool.hpp"
#include "model-decoder.h"
#include "util.h"
namespace http {
namespace server2 {
/// The top-level class of the HTTP server.
class server {
 public:
  server(const server &) = delete;
  server &operator=(const server &) = delete;
  /// Construct the server to listen on the specified TCP address and port.
  /// doc_root is inherited from the Boost example and is currently unused.
  explicit server(const std::string &address, const std::string &port,
                  const std::string &doc_root, std::size_t io_context_pool_size,
                  asio::io_context &decoder_context,
                  std::map<std::string, std::string> &model_path,
                  int thread_num);
  /// Run the server's io_context loop.
  void run();
 private:
  /// Perform an asynchronous accept operation.
  void do_accept();
  /// Wait for a request to stop the server.
  void do_await_stop();
  /// The pool of io_context objects used to perform asynchronous operations.
  io_context_pool io_context_pool_;
  /// io_context whose threads run the ASR decoding jobs.
  asio::io_context &decoder_context;
  /// The signal_set is used to register for process termination notifications.
  asio::signal_set signals_;
  /// Acceptor used to listen for incoming connections.
  asio::ip::tcp::acceptor acceptor_;
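  /// Shared ASR engine wrapper handed to every connection.
  std::shared_ptr<ModelDecoder> model_decoder;
  /// Counter that gives each accepted connection an id.
  std::atomic<int> atom_id;
  std::mutex m_lock;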
};
}  // namespace server2
}  // namespace http
#endif  // HTTP_SERVER2_SERVER_HPP
runtime/http/readme.md
New file
@@ -0,0 +1,58 @@
# Advanced Development Guide (File transcription service) ([click](../docs/SDK_advanced_guide_offline.md))
# Real-time Speech Transcription Service Development Guide ([click](../docs/SDK_advanced_guide_online.md))
# To compile it yourself, follow the steps below.
## Building for Linux/Unix
### Download onnxruntime
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
```
### Download ffmpeg
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/ffmpeg-master-latest-linux64-gpl-shared.tar.xz
tar -xvf ffmpeg-master-latest-linux64-gpl-shared.tar.xz
```
### Install deps
```shell
# openblas
sudo apt-get install libopenblas-dev #ubuntu
# sudo yum -y install openblas-devel #centos
# openssl
sudo apt-get install libssl-dev #ubuntu
# sudo yum -y install openssl-devel #centos
```
### Build runtime
```shell
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/runtime/http
mkdir build && cd build
cmake  -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0 -DFFMPEG_DIR=/path/to/ffmpeg-master-latest-linux64-gpl-shared
make -j 4
```
### Run
```shell
./funasr-http-server \
  --lm-dir '' \
  --itn-dir '' \
  --download-model-dir ${download_model_dir} \
  --model-dir ${model_dir} \
  --vad-dir ${vad_dir} \
  --punc-dir ${punc_dir} \
  --decoder-thread-num ${decoder_thread_num} \
  --io-thread-num ${io_thread_num} \
  --port ${port}
```
### Test
```shell
curl -F "file=@example.wav" 127.0.0.1:${port}
```
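The curl test command posts the audio as `multipart/form-data`, with one part named `file`. If you are writing a client without curl, the following hypothetical sketch shows the body curl builds for that command. `BuildMultipartBody` is an illustrative name. The boundary string and the part's `application/octet-stream` Content-Type are assumptions, since curl chooses its own boundary. The request itself would carry `Content-Type: multipart/form-data; boundary=<boundary>` and a matching Content-Length.

```cpp
#include <cassert>
#include <string>

// Hypothetical sketch of a multipart/form-data body with a single file part
// named "file", roughly what `curl -F "file=@example.wav"` sends.
std::string BuildMultipartBody(const std::string &boundary,
                               const std::string &filename,
                               const std::string &file_bytes) {
  std::string body;
  body += "--" + boundary + "\r\n";  // opening delimiter
  body += "Content-Disposition: form-data; name=\"file\"; filename=\"" +
          filename + "\"\r\n";
  body += "Content-Type: application/octet-stream\r\n";  // assumed part type
  body += "\r\n";       // blank line separates part headers from data
  body += file_bytes;   // raw audio bytes; may contain '\0'
  body += "\r\n--" + boundary + "--\r\n";  // closing delimiter
  return body;
}
```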
runtime/http/readme_zh.md
New file
@@ -0,0 +1,61 @@
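# FunASR Offline File Transcription Service Development Guide ([click here](../docs/SDK_advanced_guide_offline_zh.md))
# FunASR Real-time Speech Dictation Service Development Guide ([click here](../docs/SDK_advanced_guide_online_zh.md))
# To compile it yourself, follow the steps below
## Building on Linux/Unix
### Download onnxruntime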
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
```
### Download ffmpeg
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/ffmpeg-master-latest-linux64-gpl-shared.tar.xz
tar -xvf ffmpeg-master-latest-linux64-gpl-shared.tar.xz
```
### Install dependencies
```shell
# openblas
sudo apt-get install libopenblas-dev #ubuntu
# sudo yum -y install openblas-devel #centos
# openssl
sudo apt-get install libssl-dev #ubuntu
# sudo yum -y install openssl-devel #centos
```
### Build runtime
```shell
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/runtime/http
mkdir build && cd build
cmake  -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0 -DFFMPEG_DIR=/path/to/ffmpeg-master-latest-linux64-gpl-shared
make -j 4
```
### Run
```shell
./funasr-http-server \
  --lm-dir '' \
  --itn-dir '' \
  --download-model-dir ${download_model_dir} \
  --model-dir ${model_dir} \
  --vad-dir ${vad_dir} \
  --punc-dir ${punc_dir} \
  --decoder-thread-num ${decoder_thread_num} \
  --io-thread-num ${io_thread_num} \
  --port ${port}
```
### Test
```shell
curl -F "file=@example.wav" 127.0.0.1:${port}
```
runtime/http/requirements_install.md
New file
@@ -0,0 +1,15 @@
#### Download onnxruntime
```shell
bash third_party/download_onnxruntime.sh
```
#### Download ffmpeg
```shell
bash third_party/download_ffmpeg.sh
```
#### Install openblas and openssl
```shell
sudo apt-get install libopenblas-dev libssl-dev #ubuntu
# sudo yum -y install openblas-devel openssl-devel #centos
```
runtime/onnxruntime/CMakeLists.txt
@@ -4,6 +4,7 @@
option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN requires openfst
option(GPU "Whether to build with GPU" OFF)
# set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
@@ -49,6 +50,17 @@
include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi)
include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)
if(GPU)
    add_definitions(-DUSE_GPU)
    set(TORCH_DIR "/usr/local/lib/python3.8/dist-packages/torch")
    set(TORCH_BLADE_DIR "/usr/local/lib/python3.8/dist-packages/torch_blade")
    include_directories(${TORCH_DIR}/include)
    include_directories(${TORCH_DIR}/include/torch/csrc/api/include)
    link_directories(${TORCH_DIR}/lib)
    link_directories(${TORCH_BLADE_DIR})
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
endif()
if(ENABLE_GLOG)
    include_directories(${PROJECT_SOURCE_DIR}/third_party/glog/src)
    set(BUILD_TESTING OFF)
runtime/onnxruntime/bin/CMakeLists.txt
@@ -10,33 +10,43 @@
endif()
add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline PUBLIC funasr)
add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-vad PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-vad PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-vad PUBLIC funasr)
add_executable(funasr-onnx-online-asr "funasr-onnx-online-asr.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-asr PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-asr PUBLIC funasr)
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-punc PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-punc PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
add_executable(funasr-onnx-2pass "funasr-onnx-2pass.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-2pass PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-2pass PUBLIC funasr)
add_executable(funasr-onnx-2pass-rtf "funasr-onnx-2pass-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-2pass-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-2pass-rtf PUBLIC funasr)
add_executable(funasr-onnx-online-rtf "funasr-onnx-online-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-rtf PUBLIC funasr)
# include_directories(${FFMPEG_DIR}/include)
runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -52,7 +52,7 @@
    std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(asr_handle, nn_hotwords_);
    
    // warm up
    for (size_t i = 0; i < 1; i++)
    for (size_t i = 0; i < 10; i++)
    {
        FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs, true, decoder_handle);
        if(result){
@@ -127,6 +127,7 @@
    TCLAP::CmdLine cmd("funasr-onnx-offline-rtf", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "true (Default): load model_quant.onnx in model_dir; false: load model.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    bladedisc("", BLADEDISC, "true (Default), load the model of bladedisc in model_dir.", false, "true", "string");
    TCLAP::ValueArg<std::string>    vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "true (Default): load model_quant.onnx in vad_dir; false: load model.onnx in vad_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
@@ -140,11 +141,14 @@
    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
    TCLAP::ValueArg<std::int32_t>   audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
    TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
    TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", false, 1, "int32_t");
    TCLAP::ValueArg<std::string>    hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
    TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU for inference, default is false", false);
    TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
    cmd.add(model_dir);
    cmd.add(quantize);
    cmd.add(bladedisc);
    cmd.add(vad_dir);
    cmd.add(vad_quant);
    cmd.add(punc_dir);
@@ -159,11 +163,14 @@
    cmd.add(wav_path);
    cmd.add(audio_fs);
    cmd.add(thread_num);
    cmd.add(use_gpu);
    cmd.add(batch_size);
    cmd.parse(argc, argv);
    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(bladedisc, BLADEDISC, model_path);
    GetValue(vad_dir, VAD_DIR, model_path);
    GetValue(vad_quant, VAD_QUANT, model_path);
    GetValue(punc_dir, PUNC_DIR, model_path);
@@ -175,7 +182,9 @@
    struct timeval start, end;
    gettimeofday(&start, nullptr);
    FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, 1);
    bool use_gpu_ = use_gpu.getValue();
    int batch_size_ = batch_size.getValue();
    FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, 1, use_gpu_, batch_size_);
    if (!asr_handle)
    {
runtime/onnxruntime/bin/funasr-onnx-offline.cpp
@@ -19,6 +19,7 @@
#include "com-define.h"
#include <unordered_map>
#include "util.h"
#include "audio.h"
using namespace std;
bool is_target_file(const std::string& filename, const std::string target) {
@@ -44,6 +45,7 @@
    TCLAP::CmdLine cmd("funasr-onnx-offline", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    bladedisc("", BLADEDISC, "true (Default), load the model of bladedisc in model_dir.", false, "true", "string");
    TCLAP::ValueArg<std::string>    vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
@@ -57,9 +59,12 @@
    TCLAP::ValueArg<std::string>    wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
    TCLAP::ValueArg<std::int32_t>   audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
    TCLAP::ValueArg<std::string>    hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
    TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU for inference, default is false", false);
    TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
    cmd.add(model_dir);
    cmd.add(quantize);
    cmd.add(bladedisc);
    cmd.add(vad_dir);
    cmd.add(vad_quant);
    cmd.add(punc_dir);
@@ -73,11 +78,14 @@
    cmd.add(wav_path);
    cmd.add(audio_fs);
    cmd.add(hotword);
    cmd.add(use_gpu);
    cmd.add(batch_size);
    cmd.parse(argc, argv);
    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(bladedisc, BLADEDISC, model_path);
    GetValue(vad_dir, VAD_DIR, model_path);
    GetValue(vad_quant, VAD_QUANT, model_path);
    GetValue(punc_dir, PUNC_DIR, model_path);
@@ -89,7 +97,9 @@
    struct timeval start, end;
    gettimeofday(&start, nullptr);
    int thread_num = 1;
    FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num);
    bool use_gpu_ = use_gpu.getValue();
    int batch_size_ = batch_size.getValue();
    FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num, use_gpu_, batch_size_);
    if (!asr_hanlde)
    {
@@ -156,7 +166,33 @@
    for (int i = 0; i < wav_list.size(); i++) {
        auto& wav_file = wav_list[i];
        auto& wav_id = wav_ids[i];
        gettimeofday(&start, nullptr);
        // For debug:begin
        // int32_t sampling_rate_ = audio_fs.getValue();
        // funasr::Audio audio(1);
        // if(is_target_file(wav_file.c_str(), "wav")){
        //     if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
        //         LOG(ERROR)<<"Failed to load "<< wav_file;
        //         exit(-1);
        //     }
        // }else if(is_target_file(wav_file.c_str(), "pcm")){
        //     if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
        //         LOG(ERROR)<<"Failed to load "<< wav_file;
        //         exit(-1);
        //     }
        // }else{
        //     if (!audio.FfmpegLoad(wav_file.c_str(), true)){
        //         LOG(ERROR)<<"Failed to load "<< wav_file;
        //         exit(-1);
        //     }
        // }
        // char* speech_buff = audio.GetSpeechChar();
        // int buff_len = audio.GetSpeechLen()*2;
        // gettimeofday(&start, nullptr);
        // FUNASR_RESULT result=FunOfflineInferBuffer(asr_hanlde, speech_buff, buff_len, RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), "pcm", true, decoder_handle);
        // For debug:end
        FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), true, decoder_handle);
        gettimeofday(&end, nullptr);
        seconds = (end.tv_sec - start.tv_sec);
runtime/onnxruntime/include/audio.h
@@ -83,9 +83,11 @@
    int FetchTpass(AudioFrame *&frame);
    int Fetch(float *&dout, int &len, int &flag);
    int Fetch(float *&dout, int &len, int &flag, float &start_time);
    int Fetch(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
    int FetchDynamic(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
    void Padding();
    void Split(OfflineStream* offline_streamj);
    void CutSplit(OfflineStream* offline_streamj);
    void CutSplit(OfflineStream* offline_streamj, std::vector<int> &index_vector);
    void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
    void Split(VadModel* vad_obj, int chunk_len, bool input_finished=true, ASR_TYPE asr_mode=ASR_TWO_PASS);
    float GetTimeLen();
runtime/onnxruntime/include/com-define.h
@@ -51,6 +51,15 @@
#define QUANT_MODEL_NAME "model_quant.onnx"
#define VAD_CMVN_NAME "am.mvn"
#define VAD_CONFIG_NAME "config.yaml"
// gpu models
#define INFER_GPU "gpu"
#define BATCHSIZE "batch-size"
#define TORCH_MODEL_NAME "model.torchscripts"
#define TORCH_QUANT_MODEL_NAME "model_quant.torchscripts"
#define BLADE_MODEL_NAME "model.blade.fp16.pt"
#define BLADEDISC "bladedisc"
#define AM_CMVN_NAME "am.mvn"
#define AM_CONFIG_NAME "config.yaml"
#define LM_CONFIG_NAME "config.yaml"
runtime/onnxruntime/include/funasrruntime.h
@@ -96,7 +96,7 @@
_FUNASRAPI void                    CTTransformerUninit(FUNASR_HANDLE handle);
//OfflineStream
_FUNASRAPI FUNASR_HANDLE      FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num);
_FUNASRAPI FUNASR_HANDLE      FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
_FUNASRAPI void             FunOfflineReset(FUNASR_HANDLE handle, FUNASR_DEC_HANDLE dec_handle=nullptr);
// buffer
_FUNASRAPI FUNASR_RESULT    FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, 
@@ -106,9 +106,9 @@
_FUNASRAPI FUNASR_RESULT    FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, 
                                            QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, 
                                            int sampling_rate=16000, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
#if !defined(__APPLE__)
//#if !defined(__APPLE__)
_FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords, ASR_TYPE mode=ASR_OFFLINE);
#endif
//#endif
_FUNASRAPI void                FunOfflineUninit(FUNASR_HANDLE handle);
runtime/onnxruntime/include/model.h
@@ -5,6 +5,10 @@
#include <string>
#include <map>
#include "funasrruntime.h"
#include "vocab.h"
#include "phone-set.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
namespace funasr {
class Model {
  public:
@@ -18,13 +22,19 @@
    virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
    virtual void InitFstDecoder(){};
    virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
    virtual std::vector<std::string> Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1)
      {return std::vector<string>();};
    virtual std::string Rescoring() = 0;
    virtual void InitHwCompiler(const std::string &hw_model, int thread_num){};
    virtual void InitSegDict(const std::string &seg_dict_model){};
    virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
    virtual std::string GetLang(){return "";};
    virtual int GetAsrSampleRate() = 0;
    virtual void SetBatchSize(int batch_size) {};
    virtual int GetBatchSize() {return 0;};
    virtual Vocab* GetVocab() {return nullptr;};
    virtual Vocab* GetLmVocab() {return nullptr;};
    virtual PhoneSet* GetPhoneSet() {return nullptr;};
};
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
runtime/onnxruntime/include/offline-stream.h
@@ -14,7 +14,7 @@
namespace funasr {
class OfflineStream {
  public:
    OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
    OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
    ~OfflineStream(){};
    std::unique_ptr<VadModel> vad_handle= nullptr;
@@ -33,6 +33,6 @@
    bool use_itn=false;
};
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1);
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1, bool use_gpu=false, int batch_size=1);
} // namespace funasr
#endif
runtime/onnxruntime/src/CMakeLists.txt
@@ -1,6 +1,15 @@
file(GLOB files1 "*.cpp")
if(APPLE)
    file(GLOB itn_files "itn-*.cpp")
    list(REMOVE_ITEM files1 ${itn_files})
endif(APPLE)
list(REMOVE_ITEM files1 "${CMAKE_CURRENT_SOURCE_DIR}/paraformer-torch.cpp")
set(files ${files1})
if(GPU)
    set(files ${files} "${CMAKE_CURRENT_SOURCE_DIR}/paraformer-torch.cpp")
endif()
message("files: "${files})
@@ -23,9 +32,17 @@
    set(EXTRA_LIBS pthread yaml-cpp csrc kaldi-decoder fst glog gflags avutil avcodec avformat swresample)
    include_directories(${ONNXRUNTIME_DIR}/include)
    include_directories(${FFMPEG_DIR}/include)
    if(APPLE)
        target_link_directories(funasr PUBLIC ${ONNXRUNTIME_DIR}/lib)
        target_link_directories(funasr PUBLIC ${FFMPEG_DIR}/lib)
    endif(APPLE)
endif()
if(GPU)
    set(TORCH_DEPS torch torch_cuda torch_cpu c10 c10_cuda torch_blade ral_base_context)
endif()
#message("CXX_FLAGS "${CMAKE_CXX_FLAGS})
include_directories(${CMAKE_SOURCE_DIR}/include)
include_directories(${CMAKE_SOURCE_DIR}/third_party)
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS} ${TORCH_DEPS})
runtime/onnxruntime/src/audio.cpp
@@ -1023,6 +1023,90 @@
    }
}
int Audio::Fetch(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
{
    batch_in = std::min((int)frame_queue.size(), batch_size);
    if (batch_in == 0){
        return 0;
    } else{
        // init
        dout = new float*[batch_in];
        len = new int[batch_in];
        flag = new int[batch_in];
        start_time = new float[batch_in];
        for(int idx=0; idx < batch_in; idx++){
            AudioFrame *frame = frame_queue.front();
            frame_queue.pop();
            start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
            dout[idx] = speech_data + frame->GetStart();
            len[idx] = frame->GetLen();
            delete frame;
            flag[idx] = S_END;
        }
        return 1;
    }
}
int Audio::FetchDynamic(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
{
    //compute batch size
    queue<AudioFrame *> frame_batch;
    int max_acc = 300*1000*seg_sample;
    int max_sent = 60*1000*seg_sample;
    int bs_acc = 0;
    int max_len = 0;
    int max_batch = 1;
    #ifdef USE_GPU
        max_batch = batch_size;
    #endif
    max_batch = std::min(max_batch, (int)frame_queue.size());
    for(int idx=0; idx < max_batch; idx++){
        AudioFrame *frame = frame_queue.front();
        int length = frame->GetLen();
        if(length >= max_sent){
            if(bs_acc == 0){
                bs_acc++;
                frame_batch.push(frame);
                frame_queue.pop();
            }
            break;
        }
        max_len = std::max(max_len, frame->GetLen());
        if(max_len*(bs_acc+1) > max_acc){
            break;
        }
        bs_acc++;
        frame_batch.push(frame);
        frame_queue.pop();
    }
    batch_in = (int)frame_batch.size();
    if (batch_in == 0){
        return 0;
    } else{
        // init
        dout = new float*[batch_in];
        len = new int[batch_in];
        flag = new int[batch_in];
        start_time = new float[batch_in];
        for(int idx=0; idx < batch_in; idx++){
            AudioFrame *frame = frame_batch.front();
            frame_batch.pop();
            start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
            dout[idx] = speech_data + frame->GetStart();
            len[idx] = frame->GetLen();
            delete frame;
            flag[idx] = S_END;
        }
        return 1;
    }
}
void Audio::Padding()
{
    float num_samples = speech_len;
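`Audio::FetchDynamic` builds each batch greedily from the front of the queue. It stops when any of these happens:

- it reaches the batch limit, which is `batch_size` under `USE_GPU` and 1 otherwise;
- padding every segment to the longest one would exceed a total budget, `max_acc`;
- it meets a segment at least `max_sent` samples long, which is decoded alone.

Assuming `seg_sample` is the number of samples per millisecond, these limits correspond to 300 s of padded audio per batch and 60 s per segment.

The sketch below reproduces only that selection rule, applied to plain segment lengths. `TakeDynamicBatch` is a hypothetical name. The sketch also returns a `std::vector` instead of the `new[]` arrays that the diff hands back for the caller to free.

```cpp
#include <algorithm>
#include <deque>
#include <vector>

// Greedy batch selection modeled on Audio::FetchDynamic (hypothetical helper).
// queue:     segment lengths in samples, consumed from the front.
// max_batch: batch limit (batch_size under USE_GPU, otherwise 1 in the diff).
// max_acc:   budget for max_len * batch_count, i.e. the padded batch size.
// max_sent:  segments at least this long are decoded alone.
// Assumes max_acc >= max_sent, as in the diff (300 s vs 60 s); otherwise a
// segment between the two limits could never be taken.
std::vector<int> TakeDynamicBatch(std::deque<int>& queue, int max_batch,
                                  long long max_acc, long long max_sent) {
    std::vector<int> batch;
    long long max_len = 0;
    int limit = std::min(max_batch, static_cast<int>(queue.size()));
    for (int i = 0; i < limit; ++i) {
        int length = queue.front();
        if (length >= max_sent) {
            // A very long segment only starts an empty batch, and always ends it.
            if (batch.empty()) {
                batch.push_back(length);
                queue.pop_front();
            }
            break;
        }
        max_len = std::max<long long>(max_len, length);
        // Padding every item to max_len must stay within the budget.
        if (max_len * static_cast<long long>(batch.size() + 1) > max_acc) {
            break;
        }
        batch.push_back(length);
        queue.pop_front();
    }
    return batch;
}
```

`CutSplit` enqueues segments shortest-first, so consecutive batches hold segments of similar length. This keeps the padding wasted under the `max_acc` budget small.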
@@ -1085,7 +1169,7 @@
    }
}
void Audio::CutSplit(OfflineStream* offline_stream)
void Audio::CutSplit(OfflineStream* offline_stream, std::vector<int> &index_vector)
{
    std::unique_ptr<VadModel> vad_online_handle = make_unique<FsmnVadOnline>((FsmnVad*)(offline_stream->vad_handle).get());
    AudioFrame *frame;
@@ -1112,6 +1196,7 @@
    }    
    int speech_start_i = -1, speech_end_i =-1;
    std::vector<AudioFrame*> vad_frames;
    for(vector<int> vad_segment:vad_segments)
    {
        if(vad_segment.size() != 2){
@@ -1126,16 +1211,31 @@
        }
        if(speech_start_i!=-1 && speech_end_i!=-1){
            frame = new AudioFrame();
            int start = speech_start_i*seg_sample;
            int end = speech_end_i*seg_sample;
            frame = new AudioFrame(end-start);
            frame->SetStart(start);
            frame->SetEnd(end);
            frame_queue.push(frame);
            vad_frames.push_back(frame);
            frame = nullptr;
            speech_start_i=-1;
            speech_end_i=-1;
        }
    }
    // sort
    {
        index_vector.clear();
        index_vector.resize(vad_frames.size());
        for (int i = 0; i < index_vector.size(); ++i) {
            index_vector[i] = i;
        }
        std::sort(index_vector.begin(), index_vector.end(), [&vad_frames](const int a, const int b) {
            return vad_frames[a]->len < vad_frames[b]->len;
        });
        for (int idx : index_vector) {
            frame_queue.push(vad_frames[idx]);
        }
    }
}
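`CutSplit` now also records a permutation. `index_vector` lists the VAD segments from shortest to longest, and the queue is filled in that order. Decoded text comes back in the same sorted order, so `FunOfflineInfer` writes the k-th result to `msgs[index_vector[k]]` to restore chronological order.

Below is a minimal sketch of that round trip; the function names are hypothetical. It uses `std::stable_sort`, so equal-length segments keep their original relative order. The diff uses `std::sort`, whose order for ties is unspecified. Either works, because the permutation itself is what restores the order.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <string>
#include <vector>

// Indices of the segments ordered shortest-first (the role of index_vector).
std::vector<int> ShortestFirstOrder(const std::vector<int>& seg_lens) {
    std::vector<int> order(seg_lens.size());
    std::iota(order.begin(), order.end(), 0);
    std::stable_sort(order.begin(), order.end(),
                     [&seg_lens](int a, int b) { return seg_lens[a] < seg_lens[b]; });
    return order;
}

// Results arrive in decode (sorted) order; scatter the k-th one back to slot order[k].
std::vector<std::string> RestoreOriginalOrder(const std::vector<int>& order,
                                              const std::vector<std::string>& decoded) {
    std::vector<std::string> msgs(order.size());
    for (std::size_t k = 0; k < decoded.size() && k < order.size(); ++k) {
        msgs[order[k]] = decoded[k];
    }
    return msgs;
}
```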
runtime/onnxruntime/src/funasrruntime.cpp
@@ -33,9 +33,9 @@
        return mm;
    }
    _FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num)
    _FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
    {
        funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num);
        funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num, use_gpu, batch_size);
        return mm;
    }
@@ -74,16 +74,11 @@
        if(p_result->snippet_time == 0){
            return p_result;
        }
        int n_step = 0;
        int n_total = audio.GetQueueSize();
        while (audio.Fetch(buff, len, flag) > 0) {
            string msg = recog_obj->Forward(buff, len, input_finished);
            p_result->msg += msg;
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        return p_result;
    }
@@ -109,8 +104,6 @@
        float* buff;
        int len;
        int flag = 0;
        int n_step = 0;
        int n_total = audio.GetQueueSize();
        funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
        p_result->snippet_time = audio.GetTimeLen();
        if(p_result->snippet_time == 0){
@@ -119,11 +112,7 @@
        while (audio.Fetch(buff, len, flag) > 0) {
            string msg = recog_obj->Forward(buff, len, true);
            p_result->msg += msg;
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        return p_result;
    }
@@ -244,26 +233,53 @@
        if(p_result->snippet_time == 0){
            return p_result;
        }
        std::vector<int> index_vector={0};
        int msg_idx = 0;
        if(offline_stream->UseVad()){
            audio.CutSplit(offline_stream);
            audio.CutSplit(offline_stream, index_vector);
        }
        std::vector<string> msgs(index_vector.size());
        std::vector<float> msg_stimes(index_vector.size());
        float* buff;
        int len;
        int flag = 0;
        float** buff;
        int* len;
        int* flag;
        float* start_time;
        int batch_size = offline_stream->asr_handle->GetBatchSize();
        int batch_in = 0;
        int n_step = 0;
        int n_total = audio.GetQueueSize();
        float start_time = 0.0;
        std::string cur_stamp = "[";
        std::string lang = (offline_stream->asr_handle)->GetLang();
        while (audio.Fetch(buff, len, flag, start_time) > 0) {
        while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
            // dec reset
            funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
            if (wfst_decoder){
                wfst_decoder->StartUtterance();
            }
            string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
            vector<string> msg_batch = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, batch_in);
            for(int idx=0; idx<batch_in; idx++){
                string msg = msg_batch[idx];
                if(msg_idx < index_vector.size()){
                    msgs[index_vector[msg_idx]] = msg;
                    msg_stimes[index_vector[msg_idx]] = start_time[idx];
                    msg_idx++;
                }else{
                    LOG(ERROR) << "msg_idx: " << msg_idx <<" is out of range " << index_vector.size();
                }
            }
            // release
            delete[] buff;
            buff = nullptr;
            delete[] len;
            len = nullptr;
            delete[] flag;
            flag = nullptr;
            delete[] start_time;
            start_time = nullptr;
        }
        for(int idx=0; idx<msgs.size(); idx++){
            string msg = msgs[idx];
            std::vector<std::string> msg_vec = funasr::split(msg, '|');
            if(msg_vec.size()==0){
                continue;
@@ -276,14 +292,11 @@
            if(msg_vec.size() > 1){
                std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
                for(int i=0; i<msg_stamp.size()-1; i+=2){
                    float begin = std::stof(msg_stamp[i])+start_time;
                    float end = std::stof(msg_stamp[i+1])+start_time;
                    float begin = std::stof(msg_stamp[i])+msg_stimes[idx];
                    float end = std::stof(msg_stamp[i+1])+msg_stimes[idx];
                    cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
                }
            }
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        if(cur_stamp != "["){
            cur_stamp.erase(cur_stamp.length() - 1);
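Each decoded segment carries timestamps relative to its own start. They follow the `|` separator as a comma-separated list of begin/end pairs in seconds. The parsing now runs after all batches are decoded, so each segment's offset comes from `msg_stimes[idx]`, which is recorded when its batch is fetched. The pairs are then emitted in milliseconds.

The sketch below (hypothetical `AppendStamps`) reproduces that conversion, with one deliberate difference in the loop condition. It uses `i + 1 < stamps.size()` instead of `i < msg_stamp.size()-1`. The original form would underflow if the split ever returned an empty vector, because `size()` is unsigned.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// stamps:    relative "begin,end,begin,end,..." values in seconds for one segment.
// seg_start: absolute start of that segment in seconds (msg_stimes[idx] in the diff).
// Appends "[begin_ms,end_ms]," for every complete pair and returns the result.
std::string AppendStamps(std::string acc, const std::vector<std::string>& stamps,
                         float seg_start) {
    for (std::size_t i = 0; i + 1 < stamps.size(); i += 2) {
        float begin = std::stof(stamps[i]) + seg_start;
        float end = std::stof(stamps[i + 1]) + seg_start;
        // Truncating cast to whole milliseconds, as in the diff.
        acc += "[" + std::to_string(static_cast<int>(1000 * begin)) + "," +
               std::to_string(static_cast<int>(1000 * end)) + "],";
    }
    return acc;
}
```

The truncating cast can turn a value such as 2.0999999 s into 2099 ms. Rounding with `std::lround` would avoid that, at the cost of diverging from the diff's output.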
@@ -342,25 +355,53 @@
        if(p_result->snippet_time == 0){
            return p_result;
        }
        std::vector<int> index_vector={0};
        int msg_idx = 0;
        if(offline_stream->UseVad()){
            audio.CutSplit(offline_stream);
            audio.CutSplit(offline_stream, index_vector);
        }
        std::vector<string> msgs(index_vector.size());
        std::vector<float> msg_stimes(index_vector.size());
        float* buff;
        int len;
        int flag = 0;
        int n_step = 0;
        int n_total = audio.GetQueueSize();
        float start_time = 0.0;
        float** buff;
        int* len;
        int* flag;
        float* start_time;
        int batch_size = offline_stream->asr_handle->GetBatchSize();
        int batch_in = 0;
        std::string cur_stamp = "[";
        std::string lang = (offline_stream->asr_handle)->GetLang();
        while (audio.Fetch(buff, len, flag, start_time) > 0) {
        while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
            // dec reset
            funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
            if (wfst_decoder){
                wfst_decoder->StartUtterance();
            }
            string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
            vector<string> msg_batch = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, batch_in);
            for(int idx=0; idx<batch_in; idx++){
                string msg = msg_batch[idx];
                if(msg_idx < index_vector.size()){
                    msgs[index_vector[msg_idx]] = msg;
                    msg_stimes[index_vector[msg_idx]] = start_time[idx];
                    msg_idx++;
                }else{
                    LOG(ERROR) << "msg_idx: " << msg_idx <<" is out of range " << index_vector.size();
                }
            }
            // release
            delete[] buff;
            buff = nullptr;
            delete[] len;
            len = nullptr;
            delete[] flag;
            flag = nullptr;
            delete[] start_time;
            start_time = nullptr;
        }
        for(int idx=0; idx<msgs.size(); idx++){
            string msg = msgs[idx];
            std::vector<std::string> msg_vec = funasr::split(msg, '|');
            if(msg_vec.size()==0){
                continue;
@@ -373,15 +414,11 @@
            if(msg_vec.size() > 1){
                std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
                for(int i=0; i<msg_stamp.size()-1; i+=2){
                    float begin = std::stof(msg_stamp[i])+start_time;
                    float end = std::stof(msg_stamp[i+1])+start_time;
                    float begin = std::stof(msg_stamp[i])+msg_stimes[idx];
                    float end = std::stof(msg_stamp[i+1])+msg_stimes[idx];
                    cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
                }
            }
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        if(cur_stamp != "["){
            cur_stamp.erase(cur_stamp.length() - 1);
@@ -409,7 +446,7 @@
        return p_result;
    }
#if !defined(__APPLE__)
//#if !defined(__APPLE__)
    _FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords, ASR_TYPE mode)
    {
        if (mode == ASR_OFFLINE){
@@ -433,7 +470,7 @@
        }
        
    }
#endif
//#endif
    // APIs for 2pass-stream Infer
    _FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, 
@@ -518,8 +555,14 @@
            if (wfst_decoder){
                wfst_decoder->StartUtterance();
            }
            string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);
            float** buff;
            int* len;
            buff = new float*[1];
            len = new int[1];
            buff[0] = frame->data;
            len[0] = frame->len;
            vector<string> msgs = ((funasr::Paraformer*)asr_handle)->Forward(buff, len, frame->is_final, hw_emb, dec_handle);
            string msg = msgs.size()>0?msgs[0]:"";
            std::vector<std::string> msg_vec = funasr::split(msg, '|');  // split with timestamp
            if(msg_vec.size()==0){
                continue;
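In the two-pass path, the single final frame is now wrapped as a batch of one so it can go through the batched `Forward(float**, int*, ...)`. `batch_in` is left at its default of 1. The `msgs.size()>0` check guards against an empty result, which is what the `Model` base-class default returns.

The lines shown allocate `buff` and `len` with `new[]`. If they are not released later in the hunk, which this excerpt cuts off, they leak once per frame.

The sketch below does the same wrapping with automatic storage. `ForwardBatch` is a hypothetical stand-in for the model call and only reports each input length.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for a batched Forward(float** din, int* len, ..., int batch_in).
std::vector<std::string> ForwardBatch(float** din, int* len, int batch_in) {
    std::vector<std::string> out;
    for (int i = 0; i < batch_in; ++i) {
        // A real model would decode din[i][0 .. len[i]); here we only echo the length.
        out.push_back(din[i] != nullptr ? "len=" + std::to_string(len[i]) : "");
    }
    return out;
}

// Decode one frame through the batched interface without manual new[]/delete[].
std::string ForwardSingle(float* data, int n) {
    std::vector<float*> buff{data};  // batch of one pointer; samples stay owned by the caller
    std::vector<int> len{n};
    std::vector<std::string> msgs = ForwardBatch(buff.data(), len.data(), 1);
    return msgs.empty() ? "" : msgs[0];
}
```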
@@ -767,16 +810,45 @@
        funasr::WfstDecoder* mm = nullptr;
        if (asr_type == ASR_OFFLINE) {
            funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
            funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
            if (paraformer->lm_)
                mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                    paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
            auto paraformer = dynamic_cast<funasr::Paraformer*>(offline_stream->asr_handle.get());
            if(paraformer !=nullptr){
                if (paraformer->lm_){
                    mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                        paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
                }
                return mm;
            }
            #ifdef USE_GPU
            auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(offline_stream->asr_handle.get());
            if(paraformer_torch !=nullptr){
                if (paraformer_torch->lm_){
                    mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
                        paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
                }
                return mm;
            }
            #endif
        } else if (asr_type == ASR_TWO_PASS){
            funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
            funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get();
            if (paraformer->lm_)
                mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                    paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
            auto paraformer = dynamic_cast<funasr::Paraformer*>(tpass_stream->asr_handle.get());
            if(paraformer !=nullptr){
                if (paraformer->lm_){
                    mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                        paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
                }
                return mm;
            }
            #ifdef USE_GPU
            auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(tpass_stream->asr_handle.get());
            if(paraformer_torch !=nullptr){
                if (paraformer_torch->lm_){
                    mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
                        paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
                }
                return mm;
            }
            #endif
        }
        return mm;
    }
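`CreateDecoder` now uses `dynamic_cast` instead of a C-style cast to `funasr::Paraformer*`. Once `asr_handle` may hold a `ParaformerTorch`, the old cast would reinterpret an object of the wrong type, which is undefined behavior. This assumes `ParaformerTorch` is a sibling of `Paraformer` under `Model` rather than a subclass of it. `dynamic_cast`, by contrast, returns `nullptr` on a mismatch and lets the code try the next candidate.

The sketch uses hypothetical stand-in types, not the FunASR classes. `dynamic_cast` requires `Model` to be polymorphic (to have at least one virtual function), which the real base class satisfies through its virtual methods.

```cpp
#include <memory>
#include <string>

// Hypothetical stand-ins: two sibling backends under a polymorphic base.
struct Model { virtual ~Model() = default; };
struct CpuBackend : Model { bool has_lm = false; };
struct TorchBackend : Model { bool has_lm = false; };

// Returns which backend would build the WFST decoder, or "" when there is no
// language model or the type is not recognised (CreateDecoder returns nullptr then).
std::string DecoderSource(Model* m) {
    if (auto* cpu = dynamic_cast<CpuBackend*>(m)) {
        return cpu->has_lm ? "cpu" : "";
    }
    if (auto* torch = dynamic_cast<TorchBackend*>(m)) {
        return torch->has_lm ? "torch" : "";
    }
    return "";
}
```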
runtime/onnxruntime/src/offline-stream.cpp
@@ -1,7 +1,7 @@
#include "precomp.h"
namespace funasr {
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
    // VAD model
    if(model_path.find(VAD_DIR) != model_path.end()){
@@ -36,7 +36,19 @@
        string hw_compile_model_path;
        string seg_dict_path;
    
        asr_handle = make_unique<Paraformer>();
        if(use_gpu){
            #ifdef USE_GPU
            asr_handle = make_unique<ParaformerTorch>();
            asr_handle->SetBatchSize(batch_size);
            #else
            LOG(ERROR) <<"GPU is not supported! CPU will be used! If you want to use GPU, please add -DGPU=ON when cmake";
            asr_handle = make_unique<Paraformer>();
            use_gpu = false;
            #endif
        }else{
            asr_handle = make_unique<Paraformer>();
        }
        bool enable_hotword = false;
        hw_compile_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME);
        seg_dict_path = PathAppend(model_path.at(MODEL_DIR), MODEL_SEG_DICT);
@@ -54,6 +66,15 @@
          am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
          if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
            am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
          }
          if(use_gpu){
            am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_MODEL_NAME);
            if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
                am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_QUANT_MODEL_NAME);
            }
            if(model_path.find(BLADEDISC) != model_path.end() && model_path.at(BLADEDISC) == "true"){
                am_model_path = PathAppend(model_path.at(MODEL_DIR), BLADE_MODEL_NAME);
            }
          }
        }
        am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
@@ -120,10 +141,10 @@
#endif
}
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
    OfflineStream *mm;
    mm = new OfflineStream(model_path, thread_num);
    mm = new OfflineStream(model_path, thread_num, use_gpu, batch_size);
    return mm;
}
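In the lines shown, the constructor picks the acoustic-model file from `use_gpu`, `quantize` and `bladedisc`. On CPU it keeps the ONNX files; on GPU it switches to TorchScript. Because the `BLADEDISC` check runs last, `model.blade.fp16.pt` wins whenever that option is `"true"`. A binary built without `-DGPU=ON` logs an error and falls back to the CPU `Paraformer`.

The helper below condenses that precedence under a few assumptions:

- its name and boolean parameters are hypothetical;
- `gpu_built` stands in for the `USE_GPU` macro;
- `model.onnx` is the value of `MODEL_NAME`, which this excerpt does not show.

```cpp
#include <string>

// Condenses the am_model_path selection in OfflineStream's constructor.
// gpu_built stands in for the USE_GPU compile-time macro.
std::string ChooseAmModelFile(bool use_gpu, bool gpu_built, bool quantize, bool bladedisc) {
    if (use_gpu && !gpu_built) {
        use_gpu = false;  // the diff logs an error and falls back to the CPU Paraformer
    }
    std::string name = quantize ? "model_quant.onnx" : "model.onnx";
    if (use_gpu) {
        name = quantize ? "model_quant.torchscripts" : "model.torchscripts";
        if (bladedisc) {
            name = "model.blade.fp16.pt";  // checked last, so it takes precedence
        }
    }
    return name;
}
```

In the command-line tools above, both the `QUANTIZE` and `BLADEDISC` options default to `"true"`. Assuming `GetValue` stores those defaults in `model_path`, passing `--gpu` alone would load `model.blade.fp16.pt`.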
runtime/onnxruntime/src/paraformer-torch.cpp
New file
@@ -0,0 +1,415 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#include "precomp.h"
#include "paraformer-torch.h"
#include "encode_converter.h"
#include <cstddef>
using namespace std;
namespace funasr {
ParaformerTorch::ParaformerTorch()
:use_hotword(false){
}
// offline
void ParaformerTorch::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
    LoadConfigFromYaml(am_config.c_str());
    // knf options
    fbank_opts_.frame_opts.dither = 0;
    fbank_opts_.mel_opts.num_bins = n_mels;
    fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
    fbank_opts_.frame_opts.window_type = window_type;
    fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
    fbank_opts_.frame_opts.frame_length_ms = frame_length;
    fbank_opts_.energy_floor = 0;
    fbank_opts_.mel_opts.debug_mel = false;
    vocab = new Vocab(token_file.c_str());
    phone_set_ = new PhoneSet(token_file.c_str());
    LoadCmvn(am_cmvn.c_str());
    torch::DeviceType device = at::kCPU;
    #ifdef USE_GPU
    if (!torch::cuda::is_available()) {
        LOG(ERROR) << "CUDA is not available! Please check your GPU settings";
        exit(-1);
    } else {
        LOG(INFO) << "CUDA is available, running on GPU";
        device = at::kCUDA;
    }
    #endif
    #ifdef USE_IPEX
    torch::jit::setTensorExprFuserEnabled(false);
    #endif
    try {
        torch::jit::script::Module model = torch::jit::load(am_model, device);
        model_ = std::make_shared<TorchModule>(std::move(model));
        LOG(INFO) << "Successfully loaded model from " << am_model;
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when loading am model: " << am_model << ", " << e.what();
        exit(-1);
    }
}
void ParaformerTorch::InitLm(const std::string &lm_file,
                        const std::string &lm_cfg_file,
                        const std::string &lex_file) {
    try {
        lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
            fst::Fst<fst::StdArc>::Read(lm_file));
        if (lm_){
            lm_vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
            LOG(INFO) << "Successfully loaded lm file " << lm_file;
        }else{
            LOG(ERROR) << "Failed to load lm file " << lm_file;
        }
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when loading lm file: " << e.what();
        exit(-1);
    }
}
void ParaformerTorch::LoadConfigFromYaml(const char* filename){
    YAML::Node config;
    try{
        config = YAML::LoadFile(filename);
    }catch(exception const &e){
        LOG(ERROR) << "Error loading YAML file: the file is malformed or does not exist.";
        exit(-1);
    }
    try{
        YAML::Node frontend_conf = config["frontend_conf"];
        this->asr_sample_rate = frontend_conf["fs"].as<int>();
        YAML::Node lang_conf = config["lang"];
        if (lang_conf.IsDefined()){
            language = lang_conf.as<string>();
        }
    }catch(exception const &e){
        LOG(ERROR) << "Error when loading arguments from ASR config YAML.";
        exit(-1);
    }
}
void ParaformerTorch::InitHwCompiler(const std::string &hw_model, int thread_num) {
    // TODO
    use_hotword = true;
}
void ParaformerTorch::InitSegDict(const std::string &seg_dict_model) {
    seg_dict = new SegDict(seg_dict_model.c_str());
}
ParaformerTorch::~ParaformerTorch()
{
    if(vocab){
        delete vocab;
    }
    if(lm_vocab){
        delete lm_vocab;
    }
    if(seg_dict){
        delete seg_dict;
    }
    if(phone_set_){
        delete phone_set_;
    }
}
void ParaformerTorch::StartUtterance()
{
}
void ParaformerTorch::EndUtterance()
{
}
void ParaformerTorch::Reset()
{
}
void ParaformerTorch::FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats) {
    knf::OnlineFbank fbank_(fbank_opts_);
    std::vector<float> buf(len);
    for (int32_t i = 0; i != len; ++i) {
        buf[i] = waves[i] * 32768;
    }
    fbank_.AcceptWaveform(sample_rate, buf.data(), buf.size());
    int32_t frames = fbank_.NumFramesReady();
    for (int32_t i = 0; i != frames; ++i) {
        const float *frame = fbank_.GetFrame(i);
        std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
        asr_feats.emplace_back(frame_vector);
    }
}
void ParaformerTorch::LoadCmvn(const char *filename)
{
    ifstream cmvn_stream(filename);
    if (!cmvn_stream.is_open()) {
        LOG(ERROR) << "Failed to open file: " << filename;
        exit(-1);
    }
    string line;
    while (getline(cmvn_stream, line)) {
        istringstream iss(line);
        vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
        if (line_item[0] == "<AddShift>") {
            getline(cmvn_stream, line);
            istringstream means_lines_stream(line);
            vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
            if (means_lines[0] == "<LearnRateCoef>") {
                for (int j = 3; j < means_lines.size() - 1; j++) {
                    means_list_.push_back(stof(means_lines[j]));
                }
                continue;
            }
        }
        else if (line_item[0] == "<Rescale>") {
            getline(cmvn_stream, line);
            istringstream vars_lines_stream(line);
            vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
            if (vars_lines[0] == "<LearnRateCoef>") {
                for (int j = 3; j < vars_lines.size() - 1; j++) {
                    vars_list_.push_back(stof(vars_lines[j])*scale);
                }
                continue;
            }
        }
    }
}
string ParaformerTorch::GreedySearch(float * in, int n_len,  int64_t token_nums, bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
{
    vector<int> hyps;
    int Tmax = n_len;
    for (int i = 0; i < Tmax; i++) {
        int max_idx;
        float max_val;
        FindMax(in + i * token_nums, token_nums, max_val, max_idx);
        hyps.push_back(max_idx);
    }
    if(!is_stamp){
        return vocab->Vector2StringV2(hyps, language);
    }else{
        std::vector<string> char_list;
        std::vector<std::vector<float>> timestamp_list;
        std::string res_str;
        vocab->Vector2String(hyps, char_list);
        std::vector<string> raw_char(char_list);
        TimestampOnnx(us_alphas, us_cif_peak, char_list, res_str, timestamp_list);
        return PostProcess(raw_char, timestamp_list);
    }
}
string ParaformerTorch::BeamSearch(WfstDecoder* &wfst_decoder, float *in, int len, int64_t token_nums)
{
  return wfst_decoder->Search(in, len, token_nums);
}
string ParaformerTorch::FinalizeDecode(WfstDecoder* &wfst_decoder,
                                  bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
{
  return wfst_decoder->FinalizeDecode(is_stamp, us_alphas, us_cif_peak);
}
void ParaformerTorch::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
    std::vector<std::vector<float>> out_feats;
    int T = asr_feats.size();
    int T_lrf = ceil(1.0 * T / lfr_n);
    // Pad frames at start(copy first frame)
    for (int i = 0; i < (lfr_m - 1) / 2; i++) {
        asr_feats.insert(asr_feats.begin(), asr_feats[0]);
    }
    // Merge lfr_m frames as one,lfr_n frames per window
    T = T + (lfr_m - 1) / 2;
    std::vector<float> p;
    for (int i = 0; i < T_lrf; i++) {
        if (lfr_m <= T - i * lfr_n) {
            for (int j = 0; j < lfr_m; j++) {
                p.insert(p.end(), asr_feats[i * lfr_n + j].begin(), asr_feats[i * lfr_n + j].end());
            }
            out_feats.emplace_back(p);
            p.clear();
        } else {
            // Fill to lfr_m frames at last window if less than lfr_m frames  (copy last frame)
            int num_padding = lfr_m - (T - i * lfr_n);
            for (int j = 0; j < (asr_feats.size() - i * lfr_n); j++) {
                p.insert(p.end(), asr_feats[i * lfr_n + j].begin(), asr_feats[i * lfr_n + j].end());
            }
            for (int j = 0; j < num_padding; j++) {
                p.insert(p.end(), asr_feats[asr_feats.size() - 1].begin(), asr_feats[asr_feats.size() - 1].end());
            }
            out_feats.emplace_back(p);
            p.clear();
        }
    }
    // Apply cmvn
    for (auto &out_feat: out_feats) {
        for (int j = 0; j < means_list_.size(); j++) {
            out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
        }
    }
    asr_feats = out_feats;
}
std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
{
    WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
    int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
    int32_t feature_dim = lfr_m*in_feat_dim;
    std::vector<vector<float>> feats_batch;
    std::vector<int32_t> paraformer_length;
    int max_size = 0;
    int max_frames = 0;
    for(int index=0; index<batch_in; index++){
        std::vector<std::vector<float>> asr_feats;
        FbankKaldi(asr_sample_rate, din[index], len[index], asr_feats);
        if(asr_feats.size() != 0){
            LfrCmvn(asr_feats);
        }
        int32_t num_frames  = asr_feats.size();
        paraformer_length.emplace_back(num_frames);
        if(max_size < asr_feats.size()*feature_dim){
            max_size = asr_feats.size()*feature_dim;
            max_frames = num_frames;
        }
        std::vector<float> flattened;
        for (const auto& sub_vector : asr_feats) {
            flattened.insert(flattened.end(), sub_vector.begin(), sub_vector.end());
        }
        feats_batch.emplace_back(flattened);
    }
    torch::NoGradGuard no_grad;
    model_->eval();
    // padding
    std::vector<float> all_feats(batch_in * max_frames * feature_dim);
    for(int index=0; index<batch_in; index++){
        feats_batch[index].resize(max_size);
        std::memcpy(&all_feats[index * max_frames * feature_dim], feats_batch[index].data(),
                        max_frames * feature_dim * sizeof(float));
    }
    torch::Tensor feats =
        torch::from_blob(all_feats.data(),
                {batch_in, max_frames, feature_dim}, torch::kFloat).contiguous();
    torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
                        {batch_in}, torch::kInt32);
    // 2. forward
    #ifdef USE_GPU
    feats = feats.to(at::kCUDA);
    feat_lens = feat_lens.to(at::kCUDA);
    #endif
    std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
    vector<std::string> results;
    try {
        auto outputs = model_->forward(inputs).toTuple()->elements();
        torch::Tensor am_scores;
        torch::Tensor valid_token_lens;
        #ifdef USE_GPU
        am_scores = outputs[0].toTensor().to(at::kCPU);
        valid_token_lens = outputs[1].toTensor().to(at::kCPU);
        #else
        am_scores = outputs[0].toTensor();
        valid_token_lens = outputs[1].toTensor();
        #endif
        // timestamp
        for(int index=0; index<batch_in; index++){
            string result="";
            if(outputs.size() == 4){
                torch::Tensor us_alphas_tensor;
                torch::Tensor us_peaks_tensor;
                #ifdef USE_GPU
                us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
                us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
                #else
                us_alphas_tensor = outputs[2].toTensor();
                us_peaks_tensor = outputs[3].toTensor();
                #endif
                float* us_alphas_data = us_alphas_tensor[index].data_ptr<float>();
                std::vector<float> us_alphas(paraformer_length[index]);
                for (int i = 0; i < us_alphas.size(); i++) {
                    us_alphas[i] = us_alphas_data[i];
                }
                float* us_peaks_data = us_peaks_tensor[index].data_ptr<float>();
                std::vector<float> us_peaks(paraformer_length[index]);
                for (int i = 0; i < us_peaks.size(); i++) {
                    us_peaks[i] = us_peaks_data[i];
                }
                if (lm_ == nullptr) {
                    result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
                } else {
                    result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
                    if (input_finished) {
                        result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
                    }
                }
            }else{
                if (lm_ == nullptr) {
                    result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
                } else {
                    result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
                    if (input_finished) {
                        result = FinalizeDecode(wfst_decoder);
                    }
                }
            }
            results.push_back(result);
            if (wfst_decoder){
                wfst_decoder->StartUtterance();
            }
        }
    }
    catch (std::exception const &e)
    {
        LOG(ERROR)<<e.what();
    }
    return results;
}
std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
    // TODO
    std::vector<std::vector<float>> result(1, std::vector<float>(10, 0.0f));
    return result;
}
Vocab* ParaformerTorch::GetVocab()
{
    return vocab;
}
Vocab* ParaformerTorch::GetLmVocab()
{
    return lm_vocab;
}
PhoneSet* ParaformerTorch::GetPhoneSet()
{
    return phone_set_;
}
string ParaformerTorch::Rescoring()
{
    LOG(ERROR) << "Rescoring is not implemented for ParaformerTorch";
    return "";
}
} // namespace funasr
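`ParaformerTorch::GreedySearch` decodes each utterance independently. For each of the `valid_token_lens` output positions, it takes the argmax over the `am_scores.size(2)` vocabulary scores and then passes the ids to `Vocab`. Paraformer is non-autoregressive, so no position depends on an earlier choice.

A minimal standalone sketch of the argmax step follows. `GreedyArgmax` is a hypothetical helper, not part of FunASR. It resolves ties toward the lower id, which is an assumption and may differ from `FindMax`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Returns the highest-scoring token id for each of the first `n_len` rows of
// a row-major [n_len x token_nums] score matrix.
std::vector<int> GreedyArgmax(const float* scores, int n_len, int64_t token_nums) {
    assert(scores != nullptr && token_nums > 0);
    std::vector<int> hyps;
    hyps.reserve(n_len);
    for (int i = 0; i < n_len; ++i) {
        const float* row = scores + i * token_nums;  // start of position i
        int best = 0;
        for (int64_t k = 1; k < token_nums; ++k) {
            if (row[k] > row[best]) best = static_cast<int>(k);  // first max wins
        }
        hyps.push_back(best);
    }
    return hyps;
}
```

Passing `valid_token_lens[index]` as `n_len`, as the source does, means padded rows past the valid length are never decoded.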
runtime/onnxruntime/src/paraformer-torch.h
New file
@@ -0,0 +1,96 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#pragma once
#define C10_USE_GLOG
#include <torch/serialize.h>
#include <torch/script.h>
#include <torch/torch.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include "precomp.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "bias-lm.h"
#include "phone-set.h"
namespace funasr {
    class ParaformerTorch : public Model {
    /**
     * Author: Speech Lab of DAMO Academy, Alibaba Group
     * Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     * https://arxiv.org/pdf/2206.08317.pdf
    */
    private:
        Vocab* vocab = nullptr;
        Vocab* lm_vocab = nullptr;
        SegDict* seg_dict = nullptr;
        PhoneSet* phone_set_ = nullptr;
        //const float scale = 22.6274169979695;
        const float scale = 1.0;
        void LoadConfigFromYaml(const char* filename);
        void LoadCmvn(const char *filename);
        void LfrCmvn(std::vector<std::vector<float>> &asr_feats);
        using TorchModule = torch::jit::script::Module;
        std::shared_ptr<TorchModule> model_ = nullptr;
        std::vector<torch::Tensor> encoder_outs_;
        bool use_hotword;
    public:
        ParaformerTorch();
        ~ParaformerTorch();
        void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
        void InitHwCompiler(const std::string &hw_model, int thread_num);
        void InitSegDict(const std::string &seg_dict_model);
        std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
        void Reset();
        void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
        std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
        string GreedySearch( float* in, int n_len, int64_t token_nums,
                             bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
        string Rescoring();
        string GetLang(){return language;};
        int GetAsrSampleRate() { return asr_sample_rate; };
        void SetBatchSize(int batch_size) {batch_size_ = batch_size;};
        int GetBatchSize() {return batch_size_;};
        void StartUtterance();
        void EndUtterance();
        void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
        string BeamSearch(WfstDecoder* &wfst_decoder, float* in, int n_len, int64_t token_nums);
        string FinalizeDecode(WfstDecoder* &wfst_decoder,
                          bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
        Vocab* GetVocab();
        Vocab* GetLmVocab();
        PhoneSet* GetPhoneSet();
        knf::FbankOptions fbank_opts_;
        vector<float> means_list_;
        vector<float> vars_list_;
        int lfr_m = PARA_LFR_M;
        int lfr_n = PARA_LFR_N;
        // paraformer-offline
        std::string language="zh-cn";
        // lm
        std::shared_ptr<fst::Fst<fst::StdArc>> lm_ = nullptr;
        string window_type = "hamming";
        int frame_length = 25;
        int frame_shift = 10;
        int n_mels = 80;
        int encoder_size = 512;
        int fsmn_layers = 16;
        int fsmn_lorder = 10;
        int fsmn_dims = 512;
        float cif_threshold = 1.0;
        float tail_alphas = 0.45;
        int asr_sample_rate = MODEL_SAMPLE_RATE;
        int batch_size_ = 1;
    };
} // namespace funasr
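The header declares `LfrCmvn` together with `lfr_m` and `lfr_n`. In the .cpp, `LfrCmvn` runs these steps:

1. Left-pads `(lfr_m - 1) / 2` copies of the first frame.
2. Stacks `lfr_m` consecutive frames every `lfr_n` frames, giving `ceil(T / lfr_n)` output frames.
3. Fills a short final window with copies of the last frame.
4. Applies `(x + shift) * scale` per dimension, using the `<AddShift>` and `<Rescale>` vectors read by `LoadCmvn`.

A minimal sketch of the same steps on plain vectors follows. `LfrStackAndCmvn` is a hypothetical name. The default values of `PARA_LFR_M` and `PARA_LFR_N` are not shown here, so the caller supplies them.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

using Frames = std::vector<std::vector<float>>;

Frames LfrStackAndCmvn(Frames feats, int lfr_m, int lfr_n,
                       const std::vector<float>& shift,
                       const std::vector<float>& scale) {
    assert(!feats.empty() && lfr_m > 0 && lfr_n > 0);
    assert(shift.size() == scale.size());
    const int t_orig = static_cast<int>(feats.size());
    const int t_lfr = static_cast<int>(std::ceil(1.0 * t_orig / lfr_n));
    // Left-pad with copies of the first frame (copied first to avoid aliasing).
    const std::vector<float> first = feats.front();
    feats.insert(feats.begin(), (lfr_m - 1) / 2, first);
    const int t = static_cast<int>(feats.size());

    Frames out;
    for (int i = 0; i < t_lfr; ++i) {
        std::vector<float> stacked;
        for (int j = 0; j < lfr_m; ++j) {
            // Past the end, keep repeating the last frame.
            int idx = i * lfr_n + j;
            if (idx >= t) idx = t - 1;
            stacked.insert(stacked.end(), feats[idx].begin(), feats[idx].end());
        }
        out.push_back(std::move(stacked));
    }
    // CMVN over the leading shift.size() dimensions of each stacked frame.
    for (auto& row : out) {
        for (std::size_t j = 0; j < shift.size() && j < row.size(); ++j) {
            row[j] = (row[j] + shift[j]) * scale[j];
        }
    }
    return out;
}
```

Clamping the index to the last frame gives the same result as the explicit `num_padding` branch in the source. This follows because every window start `i * lfr_n` stays below the padded length.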
runtime/onnxruntime/src/paraformer.cpp
@@ -462,15 +462,23 @@
    asr_feats = out_feats;
}
string Paraformer::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
std::vector<std::string> Paraformer::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
{
    std::vector<std::string> results;
    string result="";
    WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
    int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
    if(batch_in != 1){
        results.push_back(result);
        return results;
    }
    std::vector<std::vector<float>> asr_feats;
    FbankKaldi(asr_sample_rate, din, len, asr_feats);
    FbankKaldi(asr_sample_rate, din[0], len[0], asr_feats);
    if(asr_feats.size() == 0){
      return "";
        results.push_back(result);
        return results;
    }
    LfrCmvn(asr_feats);
    int32_t feat_dim = lfr_m*in_feat_dim;
@@ -509,7 +517,8 @@
        if (use_hotword) {
            if(hw_emb.size()<=0){
                LOG(ERROR) << "hw_emb is null";
                return "";
                results.push_back(result);
                return results;
            }
            //PrintMat(hw_emb, "input_clas_emb");
            const int64_t hotword_shape[3] = {1, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())};
@@ -526,10 +535,10 @@
    }catch (std::exception const &e)
    {
        LOG(ERROR)<<e.what();
        return "";
        results.push_back(result);
        return results;
    }
    string result="";
    try {
        auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
        std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
@@ -577,7 +586,8 @@
        LOG(ERROR)<<e.what();
    }
    return result;
    results.push_back(result);
    return results;
}
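With the new `Forward(float** din, int* len, ..., int batch_in)` signature, the ONNX path above still returns a single empty result whenever `batch_in != 1`. Only the Torch path in paraformer-torch.cpp actually batches. There, each utterance's flattened LFR features are zero-padded to the longest utterance and copied into one contiguous `[batch_in, max_frames, feature_dim]` buffer. That buffer is wrapped with `torch::from_blob`, and the true lengths go in a separate int32 tensor.

A minimal LibTorch-free sketch of that packing follows. `PaddedBatch` and `PadBatch` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct PaddedBatch {
    std::vector<float> data;       // row-major [batch, max_frames, feature_dim]
    std::vector<int32_t> lengths;  // valid frames per utterance
    int max_frames = 0;
};

// Each input is a flattened [frames x feature_dim] feature matrix.
PaddedBatch PadBatch(const std::vector<std::vector<float>>& utts, int feature_dim) {
    assert(feature_dim > 0);
    PaddedBatch batch;
    for (const auto& u : utts) {
        assert(u.size() % feature_dim == 0);
        const int frames = static_cast<int>(u.size() / feature_dim);
        batch.lengths.push_back(frames);
        batch.max_frames = std::max(batch.max_frames, frames);
    }
    const std::size_t stride = static_cast<std::size_t>(batch.max_frames) * feature_dim;
    batch.data.assign(utts.size() * stride, 0.0f);  // zeros fill the padding
    for (std::size_t b = 0; b < utts.size(); ++b) {
        std::copy(utts[b].begin(), utts[b].end(),
                  batch.data.begin() + static_cast<std::ptrdiff_t>(b * stride));
    }
    return batch;
}
```

`torch::from_blob` does not take ownership of the memory, so the buffer must outlive any tensor that views it. In the source, `all_feats` lives until `Forward` returns, and the `.to(at::kCUDA)` call makes a device copy.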
runtime/onnxruntime/src/paraformer.h
@@ -52,13 +52,14 @@
        std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
        void Reset();
        void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
        string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
        std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
        string GreedySearch( float* in, int n_len, int64_t token_nums,
                             bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
        string Rescoring();
        string GetLang(){return language;};
        int GetAsrSampleRate() { return asr_sample_rate; };
        int GetBatchSize() {return batch_size_;};
        void StartUtterance();
        void EndUtterance();
        void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
@@ -110,6 +111,7 @@
        float cif_threshold = 1.0;
        float tail_alphas = 0.45;
        int asr_sample_rate = MODEL_SAMPLE_RATE;
        int batch_size_ = 1;
    };
} // namespace funasr
runtime/onnxruntime/src/precomp.h
@@ -64,6 +64,9 @@
#include "seg_dict.h"
#include "resample.h"
#include "paraformer.h"
#ifdef USE_GPU
#include "paraformer-torch.h"
#endif
#include "paraformer-online.h"
#include "offline-stream.h"
#include "tpass-stream.h"
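`paraformer-torch.h` is only included when the build defines `USE_GPU`. The fragment at the top of this section shows what happens when the runtime asks for GPU in a build without it: `use_gpu` is reset to `false` and the ONNX `Paraformer` is created instead. Clearing the flag also keeps the later model-path selection from looking for TorchScript files.

A minimal sketch of that compile-time selection follows. `Model`, `OnnxParaformer`, `TorchParaformer` and `CreateAsrModel` are hypothetical stand-ins, not the FunASR classes.

```cpp
#include <cassert>
#include <memory>
#include <string>

struct Model {
    virtual ~Model() = default;
    virtual std::string Name() const = 0;
};
struct OnnxParaformer : Model {
    std::string Name() const override { return "onnx"; }
};
#ifdef USE_GPU
struct TorchParaformer : Model {
    std::string Name() const override { return "torch"; }
};
#endif

// GPU is honoured only when the binary was compiled with USE_GPU; otherwise
// the request is downgraded and the caller's flag is cleared.
std::unique_ptr<Model> CreateAsrModel(bool& use_gpu) {
    if (use_gpu) {
#ifdef USE_GPU
        return std::make_unique<TorchParaformer>();
#else
        use_gpu = false;
#endif
    }
    return std::make_unique<OnnxParaformer>();
}
```

Passing `use_gpu` by reference mirrors the source, where the same variable is read again by the later `if(use_gpu)` model-path check.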
runtime/onnxruntime/third_party/jieba/include/limonp/StdExtension.hpp
@@ -70,13 +70,13 @@
  return os;
}
#ifndef USE_GPU
template<class T1, class T2>
ostream& operator << (ostream& os, const pair<T1, T2>& pr) {
  os << pr.first << ":" << pr.second ;
  return os;
}
#endif
template<class T>
string& operator << (string& str, const T& obj) {
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -326,6 +326,9 @@
    def __call__(
        self, wav_content: Union[str, np.ndarray, List[str]], hotwords: str, **kwargs
    ) -> List:
    # def __call__(
    #     self, waveform_list:list, hotwords: str, **kwargs
    # ) -> List:
        # make hotword list
        hotwords, hotwords_length = self.proc_hotword(hotwords)
        # import pdb; pdb.set_trace()
@@ -345,15 +348,47 @@
            try:
                outputs = self.bb_infer(feats, feats_len, bias_embed)
                am_scores, valid_token_lens = outputs[0], outputs[1]
                if len(outputs) == 4:
                    # for BiCifParaformer Inference
                    us_alphas, us_peaks = outputs[2], outputs[3]
                else:
                    us_alphas, us_peaks = None, None
            except ONNXRuntimeError:
                # logging.warning(traceback.format_exc())
                logging.warning("input wav is silence or noise")
                preds = [""]
            else:
                preds = self.decode(am_scores, valid_token_lens)
                for pred in preds:
                    pred = sentence_postprocess(pred)
                    asr_res.append({"preds": pred})
                if us_peaks is None:
                    for pred in preds:
                        if self.language == "en-bpe":
                            pred = sentence_postprocess_sentencepiece(pred)
                        else:
                            pred = sentence_postprocess(pred)
                        asr_res.append({"preds": pred})
                else:
                    for pred, us_peaks_ in zip(preds, us_peaks):
                        raw_tokens = pred
                        timestamp, timestamp_raw = time_stamp_lfr6_onnx(
                            us_peaks_, copy.copy(raw_tokens)
                        )
                        text_proc, timestamp_proc, _ = sentence_postprocess(
                            raw_tokens, timestamp_raw
                        )
                        # logging.warning(timestamp)
                        if len(self.plot_timestamp_to):
                            self.plot_wave_timestamp(
                                waveform_list[0], timestamp, self.plot_timestamp_to
                            )
                        asr_res.append(
                            {
                                "preds": text_proc,
                                "timestamp": timestamp_proc,
                                "raw_tokens": raw_tokens,
                            }
                        )
        return asr_res
    def proc_hotword(self, hotwords):
runtime/websocket/CMakeLists.txt
@@ -8,6 +8,10 @@
option(ENABLE_WEBSOCKET "Whether to build websocket server" ON)
option(ENABLE_PORTAUDIO "Whether to build portaudio" ON)
option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(GPU "Whether to build with GPU" OFF)
if(WIN32)
  file(REMOVE ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/config.h 
@@ -20,12 +24,16 @@
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -fPIC")
endif()
option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
if(GPU)
    add_definitions(-DUSE_GPU)
    set(TORCH_DIR "/usr/local/lib/python3.8/dist-packages/torch")
    set(TORCH_BLADE_DIR "/usr/local/lib/python3.8/dist-packages/torch_blade")
    include_directories(${TORCH_DIR}/include)
    include_directories(${TORCH_DIR}/include/torch/csrc/api/include)
    link_directories(${TORCH_DIR}/lib)
    link_directories(${TORCH_BLADE_DIR})
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
endif()
 
if(ENABLE_WEBSOCKET)
  # cmake_policy(SET CMP0135 NEW)
runtime/websocket/bin/CMakeLists.txt
@@ -1,5 +1,4 @@
if(WIN32)
  include_directories(${ONNXRUNTIME_DIR}/include)
  include_directories(${FFMPEG_DIR}/include)
@@ -12,15 +11,14 @@
  SET(RELATION_SOURCE "../../onnxruntime/src/resample.cpp" "../../onnxruntime/src/util.cpp" "../../onnxruntime/src/alignedmem.cpp" "../../onnxruntime/src/encode_converter.cpp")
endif()
add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-server-2pass "funasr-wss-server-2pass.cpp" "websocket-server-2pass.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-client "funasr-wss-client.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-client-2pass "funasr-wss-client-2pass.cpp" "microphone.cpp" ${RELATION_SOURCE})
target_link_options(funasr-wss-server PRIVATE "-Wl,--no-as-needed")
target_link_options(funasr-wss-server-2pass PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-wss-client PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})
target_link_libraries(funasr-wss-client-2pass PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY} portaudio)
target_link_libraries(funasr-wss-server PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})
runtime/websocket/bin/funasr-wss-server.cpp
@@ -56,6 +56,10 @@
        "true (Default), load the model of model_quant.onnx in model_dir. If set "
        "false, load the model of model.onnx in model_dir",
        false, "true", "string");
    TCLAP::ValueArg<std::string> bladedisc(
        "", BLADEDISC,
        "true (Default), load the BladeDISC-optimized model in model_dir; only takes effect when GPU inference is enabled.",
        false, "true", "string");
    TCLAP::ValueArg<std::string> vad_dir(
        "", VAD_DIR,
        "default: /workspace/models/vad, the vad model path, which contains model_quant.onnx, vad.yaml, vad.mvn",
@@ -121,6 +125,8 @@
        false, "/workspace/resources/hotwords.txt", "string");
    TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS, 
        "the fst hotwords incremental bias", false, 20, "int32_t");
    TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU, default is false", false);
    TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
    // add file
    cmd.add(hotword);
@@ -135,6 +141,7 @@
    cmd.add(model_dir);
    cmd.add(model_revision);
    cmd.add(quantize);
    cmd.add(bladedisc);
    cmd.add(vad_dir);
    cmd.add(vad_revision);
    cmd.add(vad_quant);
@@ -151,11 +158,14 @@
    cmd.add(io_thread_num);
    cmd.add(decoder_thread_num);
    cmd.add(model_thread_num);
    cmd.add(use_gpu);
    cmd.add(batch_size);
    cmd.parse(argc, argv);
    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(bladedisc, BLADEDISC, model_path);
    GetValue(vad_dir, VAD_DIR, model_path);
    GetValue(vad_quant, VAD_QUANT, model_path);
    GetValue(punc_dir, PUNC_DIR, model_path);
@@ -173,6 +183,8 @@
    global_beam_ = global_beam.getValue();
    lattice_beam_ = lattice_beam.getValue();
    am_scale_ = am_scale.getValue();
    bool use_gpu_ = use_gpu.getValue();
    int batch_size_ = batch_size.getValue();
    // Download model from ModelScope
    try{
@@ -468,7 +480,7 @@
    WebSocketServer websocket_srv(
        io_decoder, is_ssl, server, wss_server, s_certfile,
        s_keyfile);  // websocket server for asr engine
    websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
    websocket_srv.initAsr(model_path, s_model_thread_num, use_gpu_, batch_size_);  // init asr model
    LOG(INFO) << "decoder-thread-num: " << s_decoder_thread_num;
    LOG(INFO) << "io-thread-num: " << s_io_thread_num;
runtime/websocket/bin/websocket-server.cpp
@@ -402,11 +402,11 @@
// init asr model
void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
                              int thread_num) {
                              int thread_num, bool use_gpu, int batch_size) {
  try {
    // init model with api
    asr_handle = FunOfflineInit(model_path, thread_num);
    asr_handle = FunOfflineInit(model_path, thread_num, use_gpu, batch_size);
    LOG(INFO) << "model successfully inited";
    
    LOG(INFO) << "initAsr run check_and_clean_connection";
runtime/websocket/bin/websocket-server.h
@@ -124,7 +124,7 @@
                  std::string wav_format,
                  FUNASR_DEC_HANDLE& decoder_handle);
  void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
  void initAsr(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
  void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
  void on_open(websocketpp::connection_hdl hdl);
  void on_close(websocketpp::connection_hdl hdl);
setup.py
@@ -39,7 +39,7 @@
        "jaconv",
        "hydra-core>=1.3.2",
        "tensorboardX",
        "rotary_embedding_torch",
        # "rotary_embedding_torch",
        "openai-whisper",
    ],
    # train: The modules invoked when training only.