zhifu gao
2024-03-11 95cf2646fa6dae67bf53354f4ed5e81780d8fee9
onnx (#1460)

* qwenaudio qwenaudiochat

* whisper

* llm

* export onnx
29 files changed
2 files added
1 file deleted
950 ■■■■ files changed
README.md 17 ●●●●●
README_zh.md 15 ●●●●●
examples/industrial_data_pretraining/bicif_paraformer/export.py 19 ●●●●●
examples/industrial_data_pretraining/bicif_paraformer/export.sh 9 ●●●●
examples/industrial_data_pretraining/ct_transformer/export.py 11 ●●●●●
examples/industrial_data_pretraining/ct_transformer/export.sh 12 ●●●●●
examples/industrial_data_pretraining/ct_transformer_streaming/export.py 7 ●●●●
examples/industrial_data_pretraining/fsmn_vad_streaming/export.py 11 ●●●●●
examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh 9 ●●●●
examples/industrial_data_pretraining/paraformer/export.py 7 ●●●●
examples/industrial_data_pretraining/paraformer/export.sh 9 ●●●●
examples/industrial_data_pretraining/paraformer_streaming/demo.py 2 ●●●
examples/industrial_data_pretraining/paraformer_streaming/demo.sh 2 ●●●
examples/industrial_data_pretraining/paraformer_streaming/export.py 26 ●●●●●
examples/industrial_data_pretraining/paraformer_streaming/export.sh 28 ●●●●●
examples/industrial_data_pretraining/whisper/demo.py 2 ●●●●●
examples/industrial_data_pretraining/whisper/demo_from_openai.py 2 ●●●●●
examples/industrial_data_pretraining/whisper/infer.sh 2 ●●●●●
examples/industrial_data_pretraining/whisper/infer_from_local.sh 2 ●●●●●
examples/industrial_data_pretraining/whisper/infer_from_openai.sh 2 ●●●●●
funasr/download/name_maps_from_hub.py 19 ●●●●
funasr/models/paraformer/decoder.py 278 ●●●●●
funasr/models/paraformer_streaming/model.py 130 ●●●●●
funasr/models/sanm/attention.py 3 ●●●●●
funasr/models/sanm/attention_export.py 114 ●●●●●
funasr/models/sanm/encoder.py 2 ●●●
funasr/models/transformer/attention.py 100 ●●●●●
funasr/models/transformer/decoder.py 29 ●●●●●
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py 18 ●●●●
runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py 19 ●●●●
runtime/python/onnxruntime/funasr_onnx/punc_bin.py 19 ●●●●
runtime/python/onnxruntime/funasr_onnx/vad_bin.py 25 ●●●●●
README.md
@@ -210,7 +210,22 @@
More examples ref to [docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining)
[//]: # (FunASR supports inference and fine-tuning of models trained on industrial datasets of tens of thousands of hours. For more details, please refer to ([modelscope_egs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)). It also supports training and fine-tuning of models on academic standard datasets. For more details, please refer to([egs](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html)). The models include speech recognition (ASR), speech activity detection (VAD), punctuation recovery, language model, speaker verification, speaker separation, and multi-party conversation speech recognition. For a detailed list of models, please refer to the [Model Zoo](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/model_zoo/modelscope_models.md):)
## Export ONNX
### Command-line usage
```shell
funasr-export ++model=paraformer ++quantize=false
```
### Python
```python
from funasr import AutoModel
model = AutoModel(model="paraformer")
res = model.export(quantize=False)
```
## Deployment Service
FunASR supports deploying pre-trained or further fine-tuned models for service. Currently, it supports the following types of service deployment:
README_zh.md
@@ -210,6 +210,21 @@
```
More detailed usage ([examples](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining))
## Export ONNX
### Export from the command line
```shell
funasr-export ++model=paraformer ++quantize=false
```
### Export from Python
```python
from funasr import AutoModel
model = AutoModel(model="paraformer")
res = model.export(quantize=False)
```
<a name="服务部署"></a>
## Service Deployment
examples/industrial_data_pretraining/bicif_paraformer/export.py
@@ -6,21 +6,18 @@
 # method1, inference from model hub
 from funasr import AutoModel
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
 model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4")
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)

-# # method2, inference from local path
-# from funasr import AutoModel
-#
-# wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
-#
-# model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
-#
-# res = model.export(input=wav_file, type="onnx", quantize=False)
-# print(res)
+# method2, inference from local path
+from funasr import AutoModel
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+res = model.export(type="onnx", quantize=False)
+print(res)
examples/industrial_data_pretraining/bicif_paraformer/export.sh
@@ -11,18 +11,13 @@
 python -m funasr.bin.export \
 ++model=${model} \
 ++model_revision=${model_revision} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu"
+++quantize=false

 # method2, inference from local path
 model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"

 python -m funasr.bin.export \
 ++model=${model} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu"
+++quantize=false
examples/industrial_data_pretraining/ct_transformer/export.py
@@ -6,21 +6,18 @@
 # method1, inference from model hub
 from funasr import AutoModel
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
-model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                   model_revision="v2.0.4")
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)

 # method2, inference from local path
 from funasr import AutoModel
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
-model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch")
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch")
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)
examples/industrial_data_pretraining/ct_transformer/export.sh
@@ -5,24 +5,20 @@
 export HYDRA_FULL_ERROR=1

-model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 model_revision="v2.0.4"

 python -m funasr.bin.export \
 ++model=${model} \
 ++model_revision=${model_revision} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu"
+++quantize=false

 # method2, inference from local path
-model="/Users/zhifu/.cache/modelscope/hub/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"

 python -m funasr.bin.export \
 ++model=${model} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu"
+++quantize=false
examples/industrial_data_pretraining/ct_transformer_streaming/export.py
@@ -6,21 +6,18 @@
 # method1, inference from model hub
 from funasr import AutoModel
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
 model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
                   model_revision="v2.0.4")
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)

 # method2, inference from local path
 from funasr import AutoModel
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
 model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727")
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)
examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
@@ -7,20 +7,17 @@
 # method1, inference from model hub
 from funasr import AutoModel
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
-model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4")
+model = AutoModel(model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4")
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)

 # method2, inference from local path
 from funasr import AutoModel
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
-model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch")
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch")
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)
examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
@@ -12,18 +12,13 @@
 python -m funasr.bin.export \
 ++model=${model} \
 ++model_revision=${model_revision} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu"
+++quantize=false

 # method2, inference from local path
 model="/Users/zhifu/.cache/modelscope/hub/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"

 python -m funasr.bin.export \
 ++model=${model} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu"
+++quantize=false
examples/industrial_data_pretraining/paraformer/export.py
@@ -8,21 +8,18 @@
 from funasr import AutoModel
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
 model = AutoModel(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4")
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)

 # method2, inference from local path
 from funasr import AutoModel
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
 model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)
examples/industrial_data_pretraining/paraformer/export.sh
@@ -12,10 +12,8 @@
 python -m funasr.bin.export \
 ++model=${model} \
 ++model_revision=${model_revision} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu"
+++quantize=false
# method2, inference from local path
@@ -23,8 +21,5 @@
 python -m funasr.bin.export \
 ++model=${model} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu" \
-++debug=false
+++quantize=false
examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -9,7 +9,7 @@
encoder_chunk_look_back = 0 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 0 #number of encoder chunks to lookback for decoder cross-attention
-model = AutoModel(model="damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.4")
+model = AutoModel(model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.4")
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
            chunk_size=chunk_size,
            encoder_chunk_look_back=encoder_chunk_look_back,
examples/industrial_data_pretraining/paraformer_streaming/demo.sh
@@ -1,5 +1,5 @@
-model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
+model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
model_revision="v2.0.4"
python funasr/bin/inference.py \
examples/industrial_data_pretraining/paraformer_streaming/export.py
New file
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# method1, inference from model hub
from funasr import AutoModel
model = AutoModel(model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online",
                  model_revision="v2.0.4")
res = model.export(type="onnx", quantize=False)
print(res)
# method2, inference from local path
from funasr import AutoModel
model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online")
res = model.export(type="onnx", quantize=False)
print(res)
examples/industrial_data_pretraining/paraformer_streaming/export.sh
New file
@@ -0,0 +1,28 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# method1, inference from model hub
export HYDRA_FULL_ERROR=1
model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online"
model_revision="v2.0.4"
python -m funasr.bin.export \
++model=${model} \
++model_revision=${model_revision} \
++type="onnx" \
++quantize=false \
++device="cpu"
# method2, inference from local path
model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online"
python -m funasr.bin.export \
++model=${model} \
++type="onnx" \
++quantize=false \
++device="cpu" \
++debug=false
examples/industrial_data_pretraining/whisper/demo.py
@@ -3,6 +3,8 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
+# To install requirements: pip3 install -U openai-whisper
from funasr import AutoModel
model = AutoModel(model="iic/Whisper-large-v3",
examples/industrial_data_pretraining/whisper/demo_from_openai.py
@@ -3,6 +3,8 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
+# To install requirements: pip3 install -U openai-whisper
from funasr import AutoModel
# model = AutoModel(model="Whisper-small", hub="openai")
examples/industrial_data_pretraining/whisper/infer.sh
@@ -1,6 +1,8 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
+# To install requirements: pip3 install -U openai-whisper
# method1, inference from model hub
# for more input type, please ref to readme.md
examples/industrial_data_pretraining/whisper/infer_from_local.sh
@@ -1,6 +1,8 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
+# To install requirements: pip3 install -U openai-whisper
# method2, inference from local model
# for more input type, please ref to readme.md
examples/industrial_data_pretraining/whisper/infer_from_openai.sh
@@ -1,6 +1,8 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
+# To install requirements: pip3 install -U openai-whisper
# method1, inference from model hub
# for more input type, please ref to readme.md
funasr/download/name_maps_from_hub.py
@@ -1,13 +1,14 @@
 name_maps_ms = {
-    "paraformer-zh": "damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
-    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
-    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
-    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
-    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
-    "ct-punc-c": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
-    "cam++": "damo/speech_campplus_sv_zh-cn_16k-common",
+    "paraformer": "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+    "paraformer-zh": "iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+    "paraformer-en": "iic/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
+    "paraformer-en-spk": "iic/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
+    "paraformer-zh-streaming": "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
+    "fsmn-vad": "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+    "ct-punc": "iic/punc_ct-transformer_cn-en-common-vocab471067-large",
+    "ct-punc-c": "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+    "fa-zh": "iic/speech_timestamp_prediction-v1-16k-offline",
+    "cam++": "iic/speech_campplus_sv_zh-cn_16k-common",
     "Whisper-large-v2": "iic/speech_whisper-large_asr_multilingual",
     "Whisper-large-v3": "iic/Whisper-large-v3",
     "Qwen-Audio": "Qwen/Qwen-Audio",
funasr/models/paraformer/decoder.py
@@ -623,7 +623,9 @@
     def __init__(self, model,
                  max_seq_len=512,
                  model_name='decoder',
-                 onnx: bool = True, ):
+                 onnx: bool = True,
+                 **kwargs
+                 ):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        from funasr.utils.torch_function import MakePadMask
@@ -752,6 +754,162 @@
        })
        return ret
@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True, **kwargs):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        self.model = model
        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANMExport(d)
        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANMExport(d)
        for i, d in enumerate(self.model.decoders3):
            self.model.decoders3[i] = DecoderLayerSANMExport(d)
        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name
    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        *args,
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        x = tgt
        out_caches = list()
        for i, decoder in enumerate(self.model.decoders):
            in_cache = args[i]
            x, tgt_mask, memory, memory_mask, out_cache = decoder(
                x, tgt_mask, memory, memory_mask, cache=in_cache
            )
            out_caches.append(out_cache)
        if self.model.decoders2 is not None:
            for i, decoder in enumerate(self.model.decoders2):
                in_cache = args[i + len(self.model.decoders)]
                x, tgt_mask, memory, memory_mask, out_cache = decoder(
                    x, tgt_mask, memory, memory_mask, cache=in_cache
                )
                out_caches.append(out_cache)
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)
        return x, out_caches
    def get_dummy_inputs(self, enc_size):
        enc = torch.randn(2, 100, enc_size).type(torch.float32)
        enc_len = torch.tensor([30, 100], dtype=torch.int32)
        acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
        acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
        cache_num = len(self.model.decoders)
        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        cache = [
            torch.zeros((2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1),
                        dtype=torch.float32)
            for _ in range(cache_num)
        ]
        return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
    def get_input_names(self):
        cache_num = len(self.model.decoders)
        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        return ['enc', 'enc_len', 'acoustic_embeds', 'acoustic_embeds_len'] \
               + ['in_cache_%d' % i for i in range(cache_num)]
    def get_output_names(self):
        cache_num = len(self.model.decoders)
        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        return ['logits', 'sample_ids'] \
               + ['out_cache_%d' % i for i in range(cache_num)]
    def get_dynamic_axes(self):
        ret = {
            'enc': {
                0: 'batch_size',
                1: 'enc_length'
            },
            'acoustic_embeds': {
                0: 'batch_size',
                1: 'token_length'
            },
            'enc_len': {
                0: 'batch_size',
            },
            'acoustic_embeds_len': {
                0: 'batch_size',
            },
        }
        cache_num = len(self.model.decoders)
        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        ret.update({
            'in_cache_%d' % d: {
                0: 'batch_size',
            }
            for d in range(cache_num)
        })
        ret.update({
            'out_cache_%d' % d: {
                0: 'batch_size',
            }
            for d in range(cache_num)
        })
        return ret
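The `prepare_mask` helper above converts a 0/1 padding mask into a multiplicative (B, T, 1) mask and an additive attention bias: 0 on valid positions, -10000 on padded ones. A minimal pure-Python sketch of the 2-D case (no torch; the real method also handles 3-D masks and returns tensors):

```python
def prepare_mask(mask):
    # mask: batch of 0/1 rows, 1 = real frame, 0 = padding.
    mask_3d_btd = [[[v] for v in row] for row in mask]                  # (B, T, 1) multiplicative mask
    mask_4d_bhlt = [[(1 - v) * -10000.0 for v in row] for row in mask]  # additive attention bias
    return mask_3d_btd, mask_4d_bhlt

# Padded position gets a large negative bias so softmax ignores it.
m3, bias = prepare_mask([[1, 1, 0]])
```

Adding -10000 before the softmax is numerically equivalent to masking the attention weights to (near) zero, but stays ONNX-traceable.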
@tables.register("decoder_classes", "ParaformerSANDecoder")
@@ -868,3 +1026,121 @@
        else:
            return x, olens
@tables.register("decoder_classes", "ParaformerDecoderSANExport")
class ParaformerDecoderSANExport(torch.nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True, ):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        self.model = model
        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        from funasr.models.transformer.decoder import DecoderLayerExport
        from funasr.models.transformer.attention import MultiHeadedAttentionExport
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.src_attn, MultiHeadedAttention):
                d.src_attn = MultiHeadedAttentionExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerExport(d)
        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name
    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        x = tgt
        x, tgt_mask, memory, memory_mask = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)
        return x, ys_in_lens
    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)
    def is_optimizable(self):
        return True
    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
               + ['cache_%d' % i for i in range(cache_num)]
    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['y'] \
               + ['out_cache_%d' % i for i in range(cache_num)]
    def get_dynamic_axes(self):
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        return ret
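The `get_dynamic_axes` methods above all follow the same pattern: a few fixed entries plus one entry per cache tensor. A pure-Python sketch (hypothetical helper name) of how the per-cache portion of the `dynamic_axes` dict for `torch.onnx.export` is built:

```python
def cache_dynamic_axes(cache_num, prefix="cache"):
    # One entry per decoder cache tensor: axis 0 (batch) and
    # axis 2 (cache length) are marked dynamic, matching the
    # (batch, size, kernel_size) cache layout used above.
    return {
        "%s_%d" % (prefix, d): {0: "%s_%d_batch" % (prefix, d),
                                2: "%s_%d_length" % (prefix, d)}
        for d in range(cache_num)
    }

axes = cache_dynamic_axes(2)
```

Any axis not listed here is frozen to the dummy-input size at export time, which is why each cache must be declared explicitly.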
funasr/models/paraformer_streaming/model.py
@@ -561,4 +561,134 @@
        return result, meta_data
    def export(
        self,
        max_seq_len=512,
        **kwargs,
    ):
        is_onnx = kwargs.get("type", "onnx") == "onnx"
        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
        predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
        if kwargs["decoder"] == "ParaformerSANMDecoder":
            kwargs["decoder"] = "ParaformerSANMDecoderOnline"
        decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask
        if is_onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        self.forward = self._export_forward
        import copy
        import types
        encoder_model = copy.copy(self)
        decoder_model = copy.copy(self)
        # encoder
        encoder_model.forward = types.MethodType(ParaformerStreaming._export_encoder_forward, encoder_model)
        encoder_model.export_dummy_inputs = types.MethodType(ParaformerStreaming.export_encoder_dummy_inputs, encoder_model)
        encoder_model.export_input_names = types.MethodType(ParaformerStreaming.export_encoder_input_names, encoder_model)
        encoder_model.export_output_names = types.MethodType(ParaformerStreaming.export_encoder_output_names, encoder_model)
        encoder_model.export_dynamic_axes = types.MethodType(ParaformerStreaming.export_encoder_dynamic_axes, encoder_model)
        encoder_model.export_name = types.MethodType(ParaformerStreaming.export_encoder_name, encoder_model)
        # decoder
        decoder_model.forward = types.MethodType(ParaformerStreaming._export_decoder_forward, decoder_model)
        decoder_model.export_dummy_inputs = types.MethodType(ParaformerStreaming.export_decoder_dummy_inputs, decoder_model)
        decoder_model.export_input_names = types.MethodType(ParaformerStreaming.export_decoder_input_names, decoder_model)
        decoder_model.export_output_names = types.MethodType(ParaformerStreaming.export_decoder_output_names, decoder_model)
        decoder_model.export_dynamic_axes = types.MethodType(ParaformerStreaming.export_decoder_dynamic_axes, decoder_model)
        decoder_model.export_name = types.MethodType(ParaformerStreaming.export_decoder_name, decoder_model)
        return encoder_model, decoder_model
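The two `copy.copy` calls plus `types.MethodType` rebinding above are what let a single model object be traced as two separate ONNX graphs. A minimal, self-contained sketch of that pattern (class and function names here are illustrative stand-ins, not FunASR API):

```python
import copy
import types

class Model:
    """Stand-in for the full ParaformerStreaming module."""
    def __init__(self):
        self.weights = {"shared": True}  # shared state, like nn.Module parameters

def encoder_forward(self, x):
    # will be bound as the encoder clone's forward; `self` is that clone
    return ("encoder", x)

def decoder_forward(self, x):
    return ("decoder", x)

model = Model()
encoder_model = copy.copy(model)   # shallow copy: weights are shared, not duplicated
decoder_model = copy.copy(model)
encoder_model.forward = types.MethodType(encoder_forward, encoder_model)
decoder_model.forward = types.MethodType(decoder_forward, decoder_model)
```

Because `copy.copy` is shallow, both clones reference the same parameters; only the instance-level `forward` (and the other rebound `export_*` attributes) differ, which is exactly what the exporter needs.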
    def _export_encoder_forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths, "online": True}
        # batch = to_device(batch, device=self.device)
        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        alphas, _ = self.predictor.forward_cnn(enc, mask)
        return enc, enc_len, alphas
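`make_pad_mask` turns the length vector `enc_len` into a boolean mask over time frames so padded positions can be excluded before the predictor's CNN. A numpy sketch of one common convention (True marks valid frames; whether FunASR's `flip=False` matches this orientation exactly is an assumption of this sketch):

```python
import numpy as np

def make_pad_mask(lengths, max_len):
    # mask[b, t] is True while t indexes a real frame of utterance b
    time_idx = np.arange(max_len)[None, :]           # (1, max_len)
    return time_idx < np.asarray(lengths)[:, None]   # (batch, max_len)

mask = make_pad_mask([2, 4], max_len=4)
```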
    def export_encoder_dummy_inputs(self):
        speech = torch.randn(2, 30, 560)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)
    def export_encoder_input_names(self):
        return ['speech', 'speech_lengths']
    def export_encoder_output_names(self):
        return ['enc', 'enc_len', 'alphas']
    def export_encoder_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'enc': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'enc_len': {
                0: 'batch_size',
            },
            'alphas': {
                0: 'batch_size',
                1: 'feats_length'
            },
        }
    def export_encoder_name(self):
        return "model.onnx"
    def _export_decoder_forward(
        self,
        enc: torch.Tensor,
        enc_len: torch.Tensor,
        acoustic_embeds: torch.Tensor,
        acoustic_embeds_len: torch.Tensor,
        *args,
    ):
        decoder_out, out_caches = self.decoder(enc, enc_len, acoustic_embeds, acoustic_embeds_len, *args)
        sample_ids = decoder_out.argmax(dim=-1)
        return decoder_out, sample_ids, out_caches
    def export_decoder_dummy_inputs(self):
        dummy_inputs = self.decoder.get_dummy_inputs(enc_size=self.encoder._output_size)
        return dummy_inputs
    def export_decoder_input_names(self):
        return self.decoder.get_input_names()
    def export_decoder_output_names(self):
        return self.decoder.get_output_names()
    def export_decoder_dynamic_axes(self):
        return self.decoder.get_dynamic_axes()
    def export_decoder_name(self):
        return "decoder.onnx"
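Each clone exposes the same five `export_*` hooks, so one driver can feed either of them to `torch.onnx.export`. The sketch below only assembles the keyword arguments that such a driver would pass on; `collect_export_args` and `StubEncoder` are illustrative, not part of FunASR:

```python
def collect_export_args(model):
    """Gather the torch.onnx.export-style kwargs advertised by the export_* hooks."""
    return {
        "args": model.export_dummy_inputs(),
        "f": model.export_name(),
        "input_names": model.export_input_names(),
        "output_names": model.export_output_names(),
        "dynamic_axes": model.export_dynamic_axes(),
    }

class StubEncoder:
    # mirrors the hook protocol with toy values
    def export_dummy_inputs(self): return ("speech", "speech_lengths")
    def export_name(self): return "model.onnx"
    def export_input_names(self): return ["speech", "speech_lengths"]
    def export_output_names(self): return ["enc", "enc_len", "alphas"]
    def export_dynamic_axes(self): return {"speech": {0: "batch_size"}}

kwargs = collect_export_args(StubEncoder())
```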
funasr/models/sanm/attention.py
@@ -831,6 +831,3 @@
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs
funasr/models/sanm/attention_export.py
File was deleted
funasr/models/sanm/encoder.py
@@ -484,7 +484,7 @@
        return x, mask
@tables.register("encoder_classes", "SANMEncoderChunkOptExport")
@tables.register("encoder_classes", "SANMEncoderExport")
class SANMEncoderExport(nn.Module):
    def __init__(
funasr/models/transformer/attention.py
@@ -118,6 +118,106 @@
        return self.forward_attention(v, scores, mask)
class MultiHeadedAttentionExport(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.d_k = model.d_k
        self.h = model.h
        self.linear_q = model.linear_q
        self.linear_k = model.linear_k
        self.linear_v = model.linear_v
        self.linear_out = model.linear_out
        self.attn = None
        self.all_head_size = self.h * self.d_k
    def forward(self, query, key, value, mask):
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)
    def forward_qkv(self, query, key, value):
        q = self.linear_q(query)
        k = self.linear_k(key)
        v = self.linear_v(value)
        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)
        return q, k, v
    def forward_attention(self, value, scores, mask):
        scores = scores + mask
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)
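Unlike the training-time attention, this export variant applies the mask additively (`scores + mask`) before the softmax, which traces to ONNX more cleanly than boolean `masked_fill`; the mask is expected to hold large negative values at disallowed positions. A numpy illustration (the `-1e9` sentinel is an assumption, not necessarily the exact value FunASR uses):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[1.0, 2.0, 3.0]])
mask = np.array([[0.0, 0.0, -1e9]])  # additive mask: blocks the last key
attn = softmax(scores + mask)        # masked position gets ~0 weight
```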
class RelPosMultiHeadedAttentionExport(MultiHeadedAttentionExport):
    def __init__(self, model):
        super().__init__(model)
        self.linear_pos = model.linear_pos
        self.pos_bias_u = model.pos_bias_u
        self.pos_bias_v = model.pos_bias_v
    def forward(self, query, key, value, pos_emb, mask):
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
        p = self.transpose_for_scores(self.linear_pos(pos_emb))  # (batch, head, time1, d_k)
        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)
        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)
        return self.forward_attention(v, scores, mask)
    def rel_shift(self, x):
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)
        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
            ]  # only keep the positions from 0 to time2
        return x
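`rel_shift` realigns the relative-position scores with the pad-and-reshape trick instead of explicit gather indexing, which keeps the graph exportable. A numpy re-implementation mirroring the code above, run on a small tensor so the column shuffle is visible (the brute-force expectation in the check is my reading of the intended alignment, stated as an assumption):

```python
import numpy as np

def rel_shift(x):
    b, h, t1, pos = x.shape                              # pos == 2*t1 - 1
    zero_pad = np.zeros((b, h, t1, 1), dtype=x.dtype)
    x_padded = np.concatenate([zero_pad, x], axis=-1)    # (b, h, t1, pos+1)
    x_padded = x_padded.reshape(b, h, pos + 1, t1)       # shuffle columns across rows
    x = x_padded[:, :, 1:].reshape(b, h, t1, pos)        # drop the pad column
    return x[:, :, :, : pos // 2 + 1]                    # keep positions 0..time2-1

x = np.arange(1 * 1 * 3 * 5, dtype=np.float64).reshape(1, 1, 3, 5)
y = rel_shift(x)  # row i ends up holding x[i, pos//2 - i : pos//2 - i + t1]
```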
    def forward_attention(self, value, scores, mask):
        scores = scores + mask
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding (old version).
funasr/models/transformer/decoder.py
@@ -150,6 +150,35 @@
        return x, tgt_mask, memory, memory_mask
class DecoderLayerExport(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2
        self.norm3 = model.norm3
    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        residual = tgt
        tgt = self.norm1(tgt)
        tgt_q = tgt
        tgt_q_mask = tgt_mask
        x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)
        residual = x
        x = self.norm2(x)
        x = residual + self.src_attn(x, memory, memory, memory_mask)
        residual = x
        x = self.norm3(x)
        x = residual + self.feed_forward(x)
        return x, tgt_mask, memory, memory_mask
class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
    """Base class of Transfomer decoder module.
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -5,6 +5,7 @@
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import json
import copy
import librosa
@@ -55,25 +56,24 @@
        if not os.path.exists(model_file):
            print(".onnx does not exist, exporting onnx")
            try:
                from funasr.export.export_model import ModelExport
                from funasr import AutoModel
            except ImportError:
                raise ImportError(
                    "You are exporting onnx, please install funasr and try again. To install funasr, you could:\n"
                    "\npip3 install -U funasr\n"
                    "For users in China, you could install with the command:\n"
                    "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                )
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)
            model = AutoModel(model=cache_dir)
            model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
            
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)
        token_list = os.path.join(model_dir, 'tokens.json')
        with open(token_list, 'r', encoding='utf-8') as f:
            token_list = json.load(f)
        self.converter = TokenIDConverter(config['token_list'])
        self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
@@ -3,7 +3,7 @@
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import json
import copy
import librosa
import numpy as np
@@ -48,25 +48,24 @@
        if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
            print(".onnx does not exist, exporting onnx")
            try:
                from funasr.export.export_model import ModelExport
                from funasr import AutoModel
            except ImportError:
                raise ImportError(
                    "You are exporting onnx, please install funasr and try again. To install funasr, you could:\n"
                    "\npip3 install -U funasr\n"
                    "For users in China, you could install with the command:\n"
                    "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                )
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)
            model = AutoModel(model=cache_dir)
            model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)
        token_list = os.path.join(model_dir, 'tokens.json')
        with open(token_list, 'r', encoding='utf-8') as f:
            token_list = json.load(f)
        self.converter = TokenIDConverter(config['token_list'])
        self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontendOnline(
            cmvn_file=cmvn_file,
runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -6,7 +6,7 @@
from pathlib import Path
from typing import List, Union, Tuple
import numpy as np
import json
from .utils.utils import (ONNXRuntimeError,
                          OrtInferSession, get_logger,
                          read_yaml)
@@ -48,24 +48,23 @@
        if not os.path.exists(model_file):
            print(".onnx does not exist, exporting onnx")
            try:
                from funasr.export.export_model import ModelExport
                from funasr import AutoModel
            except ImportError:
                raise ImportError(
                    "You are exporting onnx, please install funasr and try again. To install funasr, you could:\n"
                    "\npip3 install -U funasr\n"
                    "For users in China, you could install with the command:\n"
                    "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                )
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)
            model = AutoModel(model=cache_dir)
            model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
            
        config_file = os.path.join(model_dir, 'punc.yaml')
        config = read_yaml(config_file)
        token_list = os.path.join(model_dir, 'tokens.json')
        with open(token_list, 'r', encoding='utf-8') as f:
            token_list = json.load(f)
        self.converter = TokenIDConverter(config['token_list'])
        self.converter = TokenIDConverter(token_list)
        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = 1
        self.punc_list = config['punc_list']
runtime/python/onnxruntime/funasr_onnx/vad_bin.py
@@ -54,19 +54,15 @@
        if not os.path.exists(model_file):
            print(".onnx does not exist, exporting onnx")
            try:
                from funasr.export.export_model import ModelExport
                from funasr import AutoModel
            except ImportError:
                raise ImportError(
                    "You are exporting onnx, please install funasr and try again. To install funasr, you could:\n"
                    "\npip3 install -U funasr\n"
                    "For users in China, you could install with the command:\n"
                    "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                )
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)
            model = AutoModel(model=cache_dir)
            model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
        config_file = os.path.join(model_dir, 'vad.yaml')
        cmvn_file = os.path.join(model_dir, 'vad.mvn')
        config = read_yaml(config_file)
@@ -222,19 +218,16 @@
        if not os.path.exists(model_file):
            print(".onnx does not exist, exporting onnx")
            try:
                from funasr.export.export_model import ModelExport
                from funasr import AutoModel
            except ImportError:
                raise ImportError(
                    "You are exporting onnx, please install funasr and try again. To install funasr, you could:\n"
                    "\npip3 install -U funasr\n"
                    "For users in China, you could install with the command:\n"
                    "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                )
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)
            model = AutoModel(model=cache_dir)
            model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
        config_file = os.path.join(model_dir, 'vad.yaml')
        cmvn_file = os.path.join(model_dir, 'vad.mvn')
        config = read_yaml(config_file)