From 95cf2646fa6dae67bf53354f4ed5e81780d8fee9 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 11 三月 2024 14:43:08 +0800
Subject: [PATCH] onnx (#1460)

---
 examples/industrial_data_pretraining/bicif_paraformer/export.sh         |    9 
 examples/industrial_data_pretraining/whisper/infer_from_openai.sh       |    2 
 funasr/models/paraformer_streaming/model.py                             |  130 ++++++++
 examples/industrial_data_pretraining/paraformer/export.py               |    7 
 examples/industrial_data_pretraining/whisper/demo.py                    |    2 
 examples/industrial_data_pretraining/ct_transformer/export.py           |   11 
 examples/industrial_data_pretraining/bicif_paraformer/export.py         |   19 
 examples/industrial_data_pretraining/paraformer_streaming/export.py     |   26 +
 funasr/models/transformer/attention.py                                  |  100 ++++++
 funasr/models/transformer/decoder.py                                    |   29 +
 examples/industrial_data_pretraining/whisper/infer.sh                   |    2 
 runtime/python/onnxruntime/funasr_onnx/punc_bin.py                      |   19 
 examples/industrial_data_pretraining/paraformer_streaming/demo.sh       |    2 
 runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py         |   19 
 examples/industrial_data_pretraining/paraformer_streaming/export.sh     |   28 +
 examples/industrial_data_pretraining/paraformer/export.sh               |    9 
 examples/industrial_data_pretraining/ct_transformer/export.sh           |   12 
 examples/industrial_data_pretraining/paraformer_streaming/demo.py       |    2 
 funasr/models/sanm/attention.py                                         |    3 
 examples/industrial_data_pretraining/ct_transformer_streaming/export.py |    7 
 runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py                |   18 
 README_zh.md                                                            |   15 
 examples/industrial_data_pretraining/whisper/demo_from_openai.py        |    2 
 examples/industrial_data_pretraining/fsmn_vad_streaming/export.py       |   11 
 funasr/models/paraformer/decoder.py                                     |  278 +++++++++++++++++
 README.md                                                               |   17 +
 runtime/python/onnxruntime/funasr_onnx/vad_bin.py                       |   25 -
 /dev/null                                                               |  114 -------
 examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh       |    9 
 funasr/download/name_maps_from_hub.py                                   |   19 
 examples/industrial_data_pretraining/whisper/infer_from_local.sh        |    2 
 funasr/models/sanm/encoder.py                                           |    2 
 32 files changed, 710 insertions(+), 240 deletions(-)

diff --git a/README.md b/README.md
index d34249d..3bd52ea 100644
--- a/README.md
+++ b/README.md
@@ -210,7 +210,22 @@
 
 More examples ref to [docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining)
 
-[//]: # (FunASR supports inference and fine-tuning of models trained on industrial datasets of tens of thousands of hours. For more details, please refer to &#40;[modelscope_egs]&#40;https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html&#41;&#41;. It also supports training and fine-tuning of models on academic standard datasets. For more details, please refer to&#40;[egs]&#40;https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html&#41;&#41;. The models include speech recognition &#40;ASR&#41;, speech activity detection &#40;VAD&#41;, punctuation recovery, language model, speaker verification, speaker separation, and multi-party conversation speech recognition. For a detailed list of models, please refer to the [Model Zoo]&#40;https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/model_zoo/modelscope_models.md&#41;:)
+
+## Export ONNX
+
+### Command-line usage
+```shell
+funasr-export ++model=paraformer ++quantize=false
+```
+
+### python
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="paraformer")
+
+res = model.export(quantize=False)
+```
 
 ## Deployment Service
 FunASR supports deploying pre-trained or further fine-tuned models for service. Currently, it supports the following types of service deployment:
diff --git a/README_zh.md b/README_zh.md
index 83e37fb..4f4e15d 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -210,6 +210,21 @@
 ```
 鏇村璇︾粏鐢ㄦ硶锛圼绀轰緥](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining)锛�
 
+## 瀵煎嚭ONNX
+### 浠庡懡浠よ瀵煎嚭
+```shell
+funasr-export ++model=paraformer ++quantize=false
+```
+
+### 浠巔ython鎸囦护瀵煎嚭
+```python
+from funasr import AutoModel
+
+model = AutoModel(model="paraformer")
+
+res = model.export(quantize=False)
+```
+
 
 <a name="鏈嶅姟閮ㄧ讲"></a>
 ## 鏈嶅姟閮ㄧ讲
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/export.py b/examples/industrial_data_pretraining/bicif_paraformer/export.py
index 78e7295..138f23a 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/export.py
+++ b/examples/industrial_data_pretraining/bicif_paraformer/export.py
@@ -6,21 +6,18 @@
 # method1, inference from model hub
 
 from funasr import AutoModel
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
 
 model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4")
 
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)
 
 
-# # method2, inference from local path
-# from funasr import AutoModel
-#
-# wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
-#
-# model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
-#
-# res = model.export(input=wav_file, type="onnx", quantize=False)
-# print(res)
\ No newline at end of file
+# method2, inference from local path
+from funasr import AutoModel
+
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+
+res = model.export(type="onnx", quantize=False)
+print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/export.sh b/examples/industrial_data_pretraining/bicif_paraformer/export.sh
index bc20a90..42b6348 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/export.sh
+++ b/examples/industrial_data_pretraining/bicif_paraformer/export.sh
@@ -11,18 +11,13 @@
 python -m funasr.bin.export \
 ++model=${model} \
 ++model_revision=${model_revision} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu"
-
+++quantize=false
 
 # method2, inference from local path
 model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 
 python -m funasr.bin.export \
 ++model=${model} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu"
\ No newline at end of file
+++quantize=false
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/export.py b/examples/industrial_data_pretraining/ct_transformer/export.py
index 3321525..8c35670 100644
--- a/examples/industrial_data_pretraining/ct_transformer/export.py
+++ b/examples/industrial_data_pretraining/ct_transformer/export.py
@@ -6,21 +6,18 @@
 # method1, inference from model hub
 
 from funasr import AutoModel
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
 
-model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                   model_revision="v2.0.4")
 
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)
 
 
 # method2, inference from local path
 from funasr import AutoModel
 
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch")
 
-model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch")
-
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/export.sh b/examples/industrial_data_pretraining/ct_transformer/export.sh
index f7849a1..7556458 100644
--- a/examples/industrial_data_pretraining/ct_transformer/export.sh
+++ b/examples/industrial_data_pretraining/ct_transformer/export.sh
@@ -5,24 +5,20 @@
 export HYDRA_FULL_ERROR=1
 
 
-model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 model_revision="v2.0.4"
 
 python -m funasr.bin.export \
 ++model=${model} \
 ++model_revision=${model_revision} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu"
+++quantize=false
 
 
 # method2, inference from local path
-model="/Users/zhifu/.cache/modelscope/hub/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 
 python -m funasr.bin.export \
 ++model=${model} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu"
\ No newline at end of file
+++quantize=false
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer_streaming/export.py b/examples/industrial_data_pretraining/ct_transformer_streaming/export.py
index 4e50501..47fa08a 100644
--- a/examples/industrial_data_pretraining/ct_transformer_streaming/export.py
+++ b/examples/industrial_data_pretraining/ct_transformer_streaming/export.py
@@ -6,21 +6,18 @@
 # method1, inference from model hub
 
 from funasr import AutoModel
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
 
 model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
                   model_revision="v2.0.4")
 
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)
 
 
 # method2, inference from local path
 from funasr import AutoModel
 
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
-
 model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727")
 
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
index 2e09523..2c8fd4d 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
@@ -7,20 +7,17 @@
 # method1, inference from model hub
 
 from funasr import AutoModel
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
 
-model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4")
+model = AutoModel(model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4")
 
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)
 
 # method2, inference from local path
 
 from funasr import AutoModel
 
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch")
 
-model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch")
-
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
index 911a1a1..1a8207a 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
@@ -12,18 +12,13 @@
 python -m funasr.bin.export \
 ++model=${model} \
 ++model_revision=${model_revision} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu"
-
+++quantize=false
 
 # method2, inference from local path
 model="/Users/zhifu/.cache/modelscope/hub/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 
 python -m funasr.bin.export \
 ++model=${model} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu"
\ No newline at end of file
+++quantize=false
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/export.py b/examples/industrial_data_pretraining/paraformer/export.py
index 43b3c18..fce4d77 100644
--- a/examples/industrial_data_pretraining/paraformer/export.py
+++ b/examples/industrial_data_pretraining/paraformer/export.py
@@ -8,21 +8,18 @@
 
 
 from funasr import AutoModel
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
 
 model = AutoModel(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4")
 
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)
 
 
 # method2, inference from local path
 from funasr import AutoModel
 
-wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
-
 model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
 
-res = model.export(input=wav_file, type="onnx", quantize=False)
+res = model.export(type="onnx", quantize=False)
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/export.sh b/examples/industrial_data_pretraining/paraformer/export.sh
index c67dca8..fc341e7 100644
--- a/examples/industrial_data_pretraining/paraformer/export.sh
+++ b/examples/industrial_data_pretraining/paraformer/export.sh
@@ -12,10 +12,8 @@
 python -m funasr.bin.export \
 ++model=${model} \
 ++model_revision=${model_revision} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu"
+++quantize=false
 
 
 # method2, inference from local path
@@ -23,8 +21,5 @@
 
 python -m funasr.bin.export \
 ++model=${model} \
-++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
-++quantize=false \
-++device="cpu" \
-++debug=false
+++quantize=false
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
index 455fe84..9885c0b 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -9,7 +9,7 @@
 encoder_chunk_look_back = 0 #number of chunks to lookback for encoder self-attention
 decoder_chunk_look_back = 0 #number of encoder chunks to lookback for decoder cross-attention
 
-model = AutoModel(model="damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.4")
+model = AutoModel(model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.4")
 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
             chunk_size=chunk_size,
             encoder_chunk_look_back=encoder_chunk_look_back,
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.sh b/examples/industrial_data_pretraining/paraformer_streaming/demo.sh
index edb7196..c3f7bb4 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.sh
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.sh
@@ -1,5 +1,5 @@
 
-model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
+model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
 model_revision="v2.0.4"
 
 python funasr/bin/inference.py \
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/export.py b/examples/industrial_data_pretraining/paraformer_streaming/export.py
new file mode 100644
index 0000000..8e22310
--- /dev/null
+++ b/examples/industrial_data_pretraining/paraformer_streaming/export.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+
+# method1, inference from model hub
+
+
+from funasr import AutoModel
+
+model = AutoModel(model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online",
+                  model_revision="v2.0.4")
+
+res = model.export(type="onnx", quantize=False)
+print(res)
+
+
+# method2, inference from local path
+from funasr import AutoModel
+
+
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online")
+
+res = model.export(type="onnx", quantize=False)
+print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/export.sh b/examples/industrial_data_pretraining/paraformer_streaming/export.sh
new file mode 100644
index 0000000..43e344b
--- /dev/null
+++ b/examples/industrial_data_pretraining/paraformer_streaming/export.sh
@@ -0,0 +1,28 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+export HYDRA_FULL_ERROR=1
+
+
+model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online"
+model_revision="v2.0.4"
+
+
+python -m funasr.bin.export \
+++model=${model} \
+++model_revision=${model_revision} \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
+
+
+# method2, inference from local path
+model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online"
+
+python -m funasr.bin.export \
+++model=${model} \
+++type="onnx" \
+++quantize=false \
+++device="cpu" \
+++debug=false
diff --git a/examples/industrial_data_pretraining/whisper/demo.py b/examples/industrial_data_pretraining/whisper/demo.py
index db8d92c..01e125d 100644
--- a/examples/industrial_data_pretraining/whisper/demo.py
+++ b/examples/industrial_data_pretraining/whisper/demo.py
@@ -3,6 +3,8 @@
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 
+# To install requirements: pip3 install -U openai-whisper
+
 from funasr import AutoModel
 
 model = AutoModel(model="iic/Whisper-large-v3",
diff --git a/examples/industrial_data_pretraining/whisper/demo_from_openai.py b/examples/industrial_data_pretraining/whisper/demo_from_openai.py
index 046e9c6..2ee8ad5 100644
--- a/examples/industrial_data_pretraining/whisper/demo_from_openai.py
+++ b/examples/industrial_data_pretraining/whisper/demo_from_openai.py
@@ -3,6 +3,8 @@
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 
+# To install requirements: pip3 install -U openai-whisper
+
 from funasr import AutoModel
 
 # model = AutoModel(model="Whisper-small", hub="openai")
diff --git a/examples/industrial_data_pretraining/whisper/infer.sh b/examples/industrial_data_pretraining/whisper/infer.sh
index 11e66c7..5beb7e2 100644
--- a/examples/industrial_data_pretraining/whisper/infer.sh
+++ b/examples/industrial_data_pretraining/whisper/infer.sh
@@ -1,6 +1,8 @@
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 
+# To install requirements: pip3 install -U openai-whisper
+
 # method1, inference from model hub
 
 # for more input type, please ref to readme.md
diff --git a/examples/industrial_data_pretraining/whisper/infer_from_local.sh b/examples/industrial_data_pretraining/whisper/infer_from_local.sh
index 885dfc6..4e12a3b 100644
--- a/examples/industrial_data_pretraining/whisper/infer_from_local.sh
+++ b/examples/industrial_data_pretraining/whisper/infer_from_local.sh
@@ -1,6 +1,8 @@
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 
+# To install requirements: pip3 install -U openai-whisper
+
 # method2, inference from local model
 
 # for more input type, please ref to readme.md
diff --git a/examples/industrial_data_pretraining/whisper/infer_from_openai.sh b/examples/industrial_data_pretraining/whisper/infer_from_openai.sh
index 461d75e..7ce92e8 100644
--- a/examples/industrial_data_pretraining/whisper/infer_from_openai.sh
+++ b/examples/industrial_data_pretraining/whisper/infer_from_openai.sh
@@ -1,6 +1,8 @@
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 
+# To install requirements: pip3 install -U openai-whisper
+
 # method1, inference from model hub
 
 # for more input type, please ref to readme.md
diff --git a/funasr/download/name_maps_from_hub.py b/funasr/download/name_maps_from_hub.py
index 5e252af..fc00843 100644
--- a/funasr/download/name_maps_from_hub.py
+++ b/funasr/download/name_maps_from_hub.py
@@ -1,13 +1,14 @@
 name_maps_ms = {
-    "paraformer-zh": "damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
-    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
-    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
-    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
-    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
-    "ct-punc-c": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
-    "cam++": "damo/speech_campplus_sv_zh-cn_16k-common",
+    "paraformer": "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+    "paraformer-zh": "iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+    "paraformer-en": "iic/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
+    "paraformer-en-spk": "iic/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
+    "paraformer-zh-streaming": "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
+    "fsmn-vad": "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+    "ct-punc": "iic/punc_ct-transformer_cn-en-common-vocab471067-large",
+    "ct-punc-c": "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+    "fa-zh": "iic/speech_timestamp_prediction-v1-16k-offline",
+    "cam++": "iic/speech_campplus_sv_zh-cn_16k-common",
     "Whisper-large-v2": "iic/speech_whisper-large_asr_multilingual",
     "Whisper-large-v3": "iic/Whisper-large-v3",
     "Qwen-Audio": "Qwen/Qwen-Audio",
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index 572a34a..59c6e1d 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -623,7 +623,9 @@
     def __init__(self, model,
                  max_seq_len=512,
                  model_name='decoder',
-                 onnx: bool = True, ):
+                 onnx: bool = True,
+                 **kwargs
+                 ):
         super().__init__()
         # self.embed = model.embed #Embedding(model.embed, max_seq_len)
         from funasr.utils.torch_function import MakePadMask
@@ -752,6 +754,162 @@
         })
         return ret
 
+@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
+class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
+    def __init__(self, model,
+                 max_seq_len=512,
+                 model_name='decoder',
+                 onnx: bool = True, **kwargs):
+        super().__init__()
+        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+        self.model = model
+        from funasr.utils.torch_function import MakePadMask
+        from funasr.utils.torch_function import sequence_mask
+
+        self.model = model
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
+        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
+
+        for i, d in enumerate(self.model.decoders):
+            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+                d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
+                d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
+            self.model.decoders[i] = DecoderLayerSANMExport(d)
+
+        if self.model.decoders2 is not None:
+            for i, d in enumerate(self.model.decoders2):
+                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+                    d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+                self.model.decoders2[i] = DecoderLayerSANMExport(d)
+
+        for i, d in enumerate(self.model.decoders3):
+            self.model.decoders3[i] = DecoderLayerSANMExport(d)
+        
+        self.output_layer = model.output_layer
+        self.after_norm = model.after_norm
+        self.model_name = model_name
+    
+    def prepare_mask(self, mask):
+        mask_3d_btd = mask[:, :, None]
+        if len(mask.shape) == 2:
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask_4d_bhlt = 1 - mask[:, None, :]
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+        
+        return mask_3d_btd, mask_4d_bhlt
+    
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        *args,
+    ):
+        
+        tgt = ys_in_pad
+        tgt_mask = self.make_pad_mask(ys_in_lens)
+        tgt_mask, _ = self.prepare_mask(tgt_mask)
+        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+        
+        memory = hs_pad
+        memory_mask = self.make_pad_mask(hlens)
+        _, memory_mask = self.prepare_mask(memory_mask)
+        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        
+        x = tgt
+        out_caches = list()
+        for i, decoder in enumerate(self.model.decoders):
+            in_cache = args[i]
+            x, tgt_mask, memory, memory_mask, out_cache = decoder(
+                x, tgt_mask, memory, memory_mask, cache=in_cache
+            )
+            out_caches.append(out_cache)
+        if self.model.decoders2 is not None:
+            for i, decoder in enumerate(self.model.decoders2):
+                in_cache = args[i + len(self.model.decoders)]
+                x, tgt_mask, memory, memory_mask, out_cache = decoder(
+                    x, tgt_mask, memory, memory_mask, cache=in_cache
+                )
+                out_caches.append(out_cache)
+        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
+            x, tgt_mask, memory, memory_mask
+        )
+        x = self.after_norm(x)
+        x = self.output_layer(x)
+        
+        return x, out_caches
+    
+    def get_dummy_inputs(self, enc_size):
+        enc = torch.randn(2, 100, enc_size).type(torch.float32)
+        enc_len = torch.tensor([30, 100], dtype=torch.int32)
+        acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
+        acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
+        cache_num = len(self.model.decoders)
+        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+            cache_num += len(self.model.decoders2)
+        cache = [
+            torch.zeros((2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1),
+                        dtype=torch.float32)
+            for _ in range(cache_num)
+        ]
+        return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
+    
+    def get_input_names(self):
+        cache_num = len(self.model.decoders)
+        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+            cache_num += len(self.model.decoders2)
+        return ['enc', 'enc_len', 'acoustic_embeds', 'acoustic_embeds_len'] \
+               + ['in_cache_%d' % i for i in range(cache_num)]
+    
+    def get_output_names(self):
+        cache_num = len(self.model.decoders)
+        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+            cache_num += len(self.model.decoders2)
+        return ['logits', 'sample_ids'] \
+               + ['out_cache_%d' % i for i in range(cache_num)]
+    
+    def get_dynamic_axes(self):
+        ret = {
+            'enc': {
+                0: 'batch_size',
+                1: 'enc_length'
+            },
+            'acoustic_embeds': {
+                0: 'batch_size',
+                1: 'token_length'
+            },
+            'enc_len': {
+                0: 'batch_size',
+            },
+            'acoustic_embeds_len': {
+                0: 'batch_size',
+            },
+            
+        }
+        cache_num = len(self.model.decoders)
+        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+            cache_num += len(self.model.decoders2)
+        ret.update({
+            'in_cache_%d' % d: {
+                0: 'batch_size',
+            }
+            for d in range(cache_num)
+        })
+        ret.update({
+            'out_cache_%d' % d: {
+                0: 'batch_size',
+            }
+            for d in range(cache_num)
+        })
+        return ret
 
 
 @tables.register("decoder_classes", "ParaformerSANDecoder")
@@ -868,3 +1026,121 @@
         else:
             return x, olens
 
+@tables.register("decoder_classes", "ParaformerDecoderSANExport")
+class ParaformerDecoderSANExport(torch.nn.Module):
+    def __init__(self, model,
+                 max_seq_len=512,
+                 model_name='decoder',
+                 onnx: bool = True, ):
+        super().__init__()
+        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+        self.model = model
+
+        from funasr.utils.torch_function import MakePadMask
+        from funasr.utils.torch_function import sequence_mask
+
+        self.model = model
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+
+        from funasr.models.transformer.decoder import DecoderLayerExport
+        from funasr.models.transformer.attention import MultiHeadedAttentionExport
+        
+        for i, d in enumerate(self.model.decoders):
+            if isinstance(d.src_attn, MultiHeadedAttention):
+                d.src_attn = MultiHeadedAttentionExport(d.src_attn)
+            self.model.decoders[i] = DecoderLayerExport(d)
+        
+        self.output_layer = model.output_layer
+        self.after_norm = model.after_norm
+        self.model_name = model_name
+    
+    def prepare_mask(self, mask):
+        mask_3d_btd = mask[:, :, None]
+        if len(mask.shape) == 2:
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask_4d_bhlt = 1 - mask[:, None, :]
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+        
+        return mask_3d_btd, mask_4d_bhlt
+    
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+    ):
+        
+        tgt = ys_in_pad
+        tgt_mask = self.make_pad_mask(ys_in_lens)
+        tgt_mask, _ = self.prepare_mask(tgt_mask)
+        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+        
+        memory = hs_pad
+        memory_mask = self.make_pad_mask(hlens)
+        _, memory_mask = self.prepare_mask(memory_mask)
+        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        
+        x = tgt
+        x, tgt_mask, memory, memory_mask = self.model.decoders(
+            x, tgt_mask, memory, memory_mask
+        )
+        x = self.after_norm(x)
+        x = self.output_layer(x)
+        
+        return x, ys_in_lens
+    
+    def get_dummy_inputs(self, enc_size):
+        tgt = torch.LongTensor([0]).unsqueeze(0)
+        memory = torch.randn(1, 100, enc_size)
+        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        cache = [
+            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
+            for _ in range(cache_num)
+        ]
+        return (tgt, memory, pre_acoustic_embeds, cache)
+    
+    def is_optimizable(self):
+        return True
+    
+    def get_input_names(self):
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
+               + ['cache_%d' % i for i in range(cache_num)]
+    
+    def get_output_names(self):
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        return ['y'] \
+               + ['out_cache_%d' % i for i in range(cache_num)]
+    
+    def get_dynamic_axes(self):
+        ret = {
+            'tgt': {
+                0: 'tgt_batch',
+                1: 'tgt_length'
+            },
+            'memory': {
+                0: 'memory_batch',
+                1: 'memory_length'
+            },
+            'pre_acoustic_embeds': {
+                0: 'acoustic_embeds_batch',
+                1: 'acoustic_embeds_length',
+            }
+        }
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        ret.update({
+            'cache_%d' % d: {
+                0: 'cache_%d_batch' % d,
+                2: 'cache_%d_length' % d
+            }
+            for d in range(cache_num)
+        })
+        return ret
+    
\ No newline at end of file
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index 4cf20de..cebbfc1 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -561,4 +561,134 @@
 
         return result, meta_data
 
def export(
    self,
    max_seq_len=512,
    **kwargs,
):
    """Rewrite this streaming Paraformer in place into export-friendly modules
    and return two shallow copies specialized as encoder/decoder export graphs.

    kwargs must carry the registered class names under "encoder", "predictor"
    and "decoder"; "type" selects ONNX ("onnx", default) vs torchscript-style
    masking.
    """
    is_onnx = kwargs.get("type", "onnx") == "onnx"
    # Replace each submodule with its "*Export" counterpart from the registry.
    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    self.encoder = encoder_class(self.encoder, onnx=is_onnx)

    predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
    self.predictor = predictor_class(self.predictor, onnx=is_onnx)

    # The offline SANM decoder is exported through its online (cached) variant.
    if kwargs["decoder"] == "ParaformerSANMDecoder":
        kwargs["decoder"] = "ParaformerSANMDecoderOnline"
    decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
    self.decoder = decoder_class(self.decoder, onnx=is_onnx)

    from funasr.utils.torch_function import MakePadMask
    from funasr.utils.torch_function import sequence_mask

    if is_onnx:
        self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
    else:
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    # NOTE(review): `_export_forward` is not among the methods added in this
    # patch (only `_export_encoder_forward`/`_export_decoder_forward` are);
    # confirm the class defines it elsewhere, otherwise this raises at export.
    self.forward = self._export_forward

    import copy
    import types
    # Shallow copies: both views share the (already rewritten) submodules,
    # only their bound export hooks differ.
    encoder_model = copy.copy(self)
    decoder_model = copy.copy(self)

    # encoder: bind the encoder-side forward and ONNX metadata providers.
    encoder_model.forward = types.MethodType(ParaformerStreaming._export_encoder_forward, encoder_model)
    encoder_model.export_dummy_inputs = types.MethodType(ParaformerStreaming.export_encoder_dummy_inputs, encoder_model)
    encoder_model.export_input_names = types.MethodType(ParaformerStreaming.export_encoder_input_names, encoder_model)
    encoder_model.export_output_names = types.MethodType(ParaformerStreaming.export_encoder_output_names, encoder_model)
    encoder_model.export_dynamic_axes = types.MethodType(ParaformerStreaming.export_encoder_dynamic_axes, encoder_model)
    encoder_model.export_name = types.MethodType(ParaformerStreaming.export_encoder_name, encoder_model)

    # decoder: same treatment with the decoder-side hooks.
    decoder_model.forward = types.MethodType(ParaformerStreaming._export_decoder_forward, decoder_model)
    decoder_model.export_dummy_inputs = types.MethodType(ParaformerStreaming.export_decoder_dummy_inputs, decoder_model)
    decoder_model.export_input_names = types.MethodType(ParaformerStreaming.export_decoder_input_names, decoder_model)
    decoder_model.export_output_names = types.MethodType(ParaformerStreaming.export_decoder_output_names, decoder_model)
    decoder_model.export_dynamic_axes = types.MethodType(ParaformerStreaming.export_decoder_dynamic_axes, decoder_model)
    decoder_model.export_name = types.MethodType(ParaformerStreaming.export_decoder_name, decoder_model)

    return encoder_model, decoder_model
+
+    def _export_encoder_forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ):
+        # a. To device
+        batch = {"speech": speech, "speech_lengths": speech_lengths, "online": True}
+        # batch = to_device(batch, device=self.device)
+    
+        enc, enc_len = self.encoder(**batch)
+        mask = self.make_pad_mask(enc_len)[:, None, :]
+        alphas, _ = self.predictor.forward_cnn(enc, mask)
+    
+        return enc, enc_len, alphas
+
def export_encoder_dummy_inputs(self):
    """Random dummy batch (B=2, T=30, F=560) for tracing the encoder graph."""
    dummy_speech = torch.randn(2, 30, 560)
    dummy_lengths = torch.tensor([6, 30], dtype=torch.int32)
    return dummy_speech, dummy_lengths
+
def export_encoder_input_names(self):
    """ONNX input names for the exported encoder graph."""
    return ["speech", "speech_lengths"]
+
def export_encoder_output_names(self):
    """ONNX output names for the exported encoder graph."""
    return ["enc", "enc_len", "alphas"]
+
def export_encoder_dynamic_axes(self):
    """Dynamic (batch/time) axes for the encoder ONNX export."""
    batch_only = {0: "batch_size"}
    batch_and_time = {0: "batch_size", 1: "feats_length"}
    return {
        "speech": dict(batch_and_time),
        "speech_lengths": dict(batch_only),
        "enc": dict(batch_and_time),
        "enc_len": dict(batch_only),
        "alphas": dict(batch_and_time),
    }
+    
def export_encoder_name(self):
    """File name the exported encoder graph is saved under."""
    return "model.onnx"
+    
+    def _export_decoder_forward(
+        self,
+        enc: torch.Tensor,
+        enc_len: torch.Tensor,
+        acoustic_embeds: torch.Tensor,
+        acoustic_embeds_len: torch.Tensor,
+        *args,
+    ):
+        decoder_out, out_caches = self.decoder(enc, enc_len, acoustic_embeds, acoustic_embeds_len, *args)
+        sample_ids = decoder_out.argmax(dim=-1)
+    
+        return decoder_out, sample_ids, out_caches
+
def export_decoder_dummy_inputs(self):
    """Dummy inputs for tracing the decoder, sized from the encoder output dim."""
    return self.decoder.get_dummy_inputs(enc_size=self.encoder._output_size)
+
def export_decoder_input_names(self):
    """Delegate ONNX input naming to the wrapped decoder exporter."""
    return self.decoder.get_input_names()
+
def export_decoder_output_names(self):
    """Delegate ONNX output naming to the wrapped decoder exporter."""
    return self.decoder.get_output_names()
+
def export_decoder_dynamic_axes(self):
    """Delegate dynamic-axis specification to the wrapped decoder exporter."""
    return self.decoder.get_dynamic_axes()
def export_decoder_name(self):
    """File name the exported decoder graph is saved under."""
    return "decoder.onnx"
\ No newline at end of file
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index c3a2f94..5f91268 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -831,6 +831,3 @@
         scores = torch.matmul(q_h, k_h.transpose(-2, -1))
         att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
         return att_outs
-
-
-
diff --git a/funasr/models/sanm/attention_export.py b/funasr/models/sanm/attention_export.py
deleted file mode 100644
index 435dd1e..0000000
--- a/funasr/models/sanm/attention_export.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import os
-import math
-
-import torch
-import torch.nn as nn
-
-
-
-
-
-
-
-
-
-
-class OnnxMultiHeadedAttention(nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.d_k = model.d_k
-        self.h = model.h
-        self.linear_q = model.linear_q
-        self.linear_k = model.linear_k
-        self.linear_v = model.linear_v
-        self.linear_out = model.linear_out
-        self.attn = None
-        self.all_head_size = self.h * self.d_k
-    
-    def forward(self, query, key, value, mask):
-        q, k, v = self.forward_qkv(query, key, value)
-        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, mask)
-
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
-    def forward_qkv(self, query, key, value):
-        q = self.linear_q(query)
-        k = self.linear_k(key)
-        v = self.linear_v(value)
-        q = self.transpose_for_scores(q)
-        k = self.transpose_for_scores(k)
-        v = self.transpose_for_scores(v)
-        return q, k, v
-    
-    def forward_attention(self, value, scores, mask):
-        scores = scores + mask
-
-        self.attn = torch.softmax(scores, dim=-1)
-        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
-        
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(new_context_layer_shape)
-        return self.linear_out(context_layer)  # (batch, time1, d_model)
-
-
-class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention):
-    def __init__(self, model):
-        super().__init__(model)
-        self.linear_pos = model.linear_pos
-        self.pos_bias_u = model.pos_bias_u
-        self.pos_bias_v = model.pos_bias_v
-    
-    def forward(self, query, key, value, pos_emb, mask):
-        q, k, v = self.forward_qkv(query, key, value)
-        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
-
-        p = self.transpose_for_scores(self.linear_pos(pos_emb)) # (batch, head, time1, d_k)
-
-        # (batch, head, time1, d_k)
-        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
-        # (batch, head, time1, d_k)
-        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
-
-        # compute attention score
-        # first compute matrix a and matrix c
-        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        # (batch, head, time1, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
-
-        # compute matrix b and matrix d
-        # (batch, head, time1, time1)
-        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
-        matrix_bd = self.rel_shift(matrix_bd)
-
-        scores = (matrix_ac + matrix_bd) / math.sqrt(
-            self.d_k
-        )  # (batch, head, time1, time2)
-
-        return self.forward_attention(v, scores, mask)
-
-    def rel_shift(self, x):
-        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
-        x_padded = torch.cat([zero_pad, x], dim=-1)
-
-        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
-        x = x_padded[:, :, 1:].view_as(x)[
-            :, :, :, : x.size(-1) // 2 + 1
-        ]  # only keep the positions from 0 to time2
-        return x
-
-    def forward_attention(self, value, scores, mask):
-        scores = scores + mask
-
-        self.attn = torch.softmax(scores, dim=-1)
-        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
-        
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(new_context_layer_shape)
-        return self.linear_out(context_layer)  # (batch, time1, d_model)
-        
diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
index 561179b..f0a3722 100644
--- a/funasr/models/sanm/encoder.py
+++ b/funasr/models/sanm/encoder.py
@@ -484,7 +484,7 @@
 
         return x, mask
 
-
+@tables.register("encoder_classes", "SANMEncoderChunkOptExport")
 @tables.register("encoder_classes", "SANMEncoderExport")
 class SANMEncoderExport(nn.Module):
     def __init__(
diff --git a/funasr/models/transformer/attention.py b/funasr/models/transformer/attention.py
index f09d642..695023d 100644
--- a/funasr/models/transformer/attention.py
+++ b/funasr/models/transformer/attention.py
@@ -118,6 +118,106 @@
         return self.forward_attention(v, scores, mask)
 
 
class MultiHeadedAttentionExport(nn.Module):
    """Export-friendly multi-head attention.

    Reuses the projection layers of a trained MultiHeadedAttention, but takes
    the mask as an additive float bias that is summed onto the raw attention
    scores (the caller is expected to pre-expand and pre-scale it).
    """

    def __init__(self, model):
        super().__init__()
        self.d_k = model.d_k
        self.h = model.h
        self.linear_q = model.linear_q
        self.linear_k = model.linear_k
        self.linear_v = model.linear_v
        self.linear_out = model.linear_out
        self.attn = None  # last attention weights, kept for inspection
        self.all_head_size = self.h * self.d_k

    def forward(self, query, key, value, mask):
        """Scaled dot-product attention; `mask` is an additive score bias."""
        q, k, v = self.forward_qkv(query, key, value)
        scale = math.sqrt(self.d_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        return self.forward_attention(v, scores, mask)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, time, d_model) -> (batch, head, time, d_k)."""
        split = x.view(x.size()[:-1] + (self.h, self.d_k))
        return split.permute(0, 2, 1, 3)

    def forward_qkv(self, query, key, value):
        """Project the inputs and split each projection into per-head tensors."""
        projected = (self.linear_q(query), self.linear_k(key), self.linear_v(value))
        return tuple(self.transpose_for_scores(t) for t in projected)

    def forward_attention(self, value, scores, mask):
        """Apply the additive mask and softmax, then recombine the heads."""
        self.attn = torch.softmax(scores + mask, dim=-1)
        context = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(context.size()[:-2] + (self.all_head_size,))
        return self.linear_out(context)  # (batch, time1, d_model)
+
+
class RelPosMultiHeadedAttentionExport(MultiHeadedAttentionExport):
    """Export-friendly multi-head attention with relative positional encoding
    (Transformer-XL style, https://arxiv.org/abs/1901.02860).

    Reuses the trained model's positional projection and the learned biases
    pos_bias_u / pos_bias_v; like the base class, the mask is an additive
    float bias on the attention scores.
    """

    def __init__(self, model):
        super().__init__(model)
        self.linear_pos = model.linear_pos
        self.pos_bias_u = model.pos_bias_u
        self.pos_bias_v = model.pos_bias_v
    
    def forward(self, query, key, value, pos_emb, mask):
        """Attention with relative position scores; `pos_emb` holds the
        relative positional embeddings, `mask` is an additive score bias."""
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
        
        p = self.transpose_for_scores(self.linear_pos(pos_emb))  # (batch, head, time1, d_k)
        
        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
        
        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
        
        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)
        
        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)
        
        return self.forward_attention(v, scores, mask)
    
    def rel_shift(self, x):
        """Shift the relative-position score matrix so that each query row is
        aligned to its own relative offsets (left-pad with zeros, then fold).
        The final slice keeps x.size(-1) // 2 + 1 positions — presumably the
        pos_emb here spans both directions (2*time-1 entries); confirm against
        the encoder's relative positional encoding."""
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)
        
        # Fold the padded matrix so every row is shifted one step further left.
        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
            ]  # only keep the positions from 0 to time2
        return x
    
    def forward_attention(self, value, scores, mask):
        """Apply the additive mask, softmax, and recombine heads.
        (Same as the base class; kept as an explicit override.)"""
        scores = scores + mask
        
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)
+
+
 class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
     """Multi-Head Attention layer with relative position encoding (old version).
 
diff --git a/funasr/models/transformer/decoder.py b/funasr/models/transformer/decoder.py
index 820de4a..1e88a25 100644
--- a/funasr/models/transformer/decoder.py
+++ b/funasr/models/transformer/decoder.py
@@ -150,6 +150,35 @@
         return x, tgt_mask, memory, memory_mask
 
 
class DecoderLayerExport(nn.Module):
    """Export-friendly pre-norm transformer decoder layer.

    Wraps a trained DecoderLayer, reusing its attention/FFN/norm submodules,
    with a forward stripped to the trace-friendly pre-norm path:
    self-attention -> cross-attention -> feed-forward, each with a residual.
    """

    def __init__(self, model):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2
        self.norm3 = model.norm3

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Run one decoder layer; `cache` is accepted for signature
        compatibility but unused. Returns the same 4-tuple it was given,
        with the target stream updated."""
        normed_tgt = self.norm1(tgt)
        x = tgt + self.self_attn(normed_tgt, normed_tgt, normed_tgt, tgt_mask)
        x = x + self.src_attn(self.norm2(x), memory, memory, memory_mask)
        x = x + self.feed_forward(self.norm3(x))
        return x, tgt_mask, memory, memory_mask
+
+
 class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
     """Base class of Transfomer decoder module.
 
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index 7afd083..e047db9 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -5,6 +5,7 @@
 import os.path
 from pathlib import Path
 from typing import List, Union, Tuple
+import json
 
 import copy
 import librosa
@@ -55,25 +56,24 @@
         if not os.path.exists(model_file):
             print(".onnx is not exist, begin to export onnx")
             try:
-                from funasr.export.export_model import ModelExport
+                from funasr import AutoModel
             except:
                 raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
                       "\npip3 install -U funasr\n" \
                       "For the users in China, you could install with the command:\n" \
                       "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
-            export_model = ModelExport(
-                cache_dir=cache_dir,
-                onnx=True,
-                device="cpu",
-                quant=quantize,
-            )
-            export_model.export(model_dir)
+
+            model = AutoModel(model=cache_dir)
+            model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
             
         config_file = os.path.join(model_dir, 'config.yaml')
         cmvn_file = os.path.join(model_dir, 'am.mvn')
         config = read_yaml(config_file)
+        token_list = os.path.join(model_dir, 'tokens.json')
+        with open(token_list, 'r', encoding='utf-8') as f:
+            token_list = json.load(f)
 
-        self.converter = TokenIDConverter(config['token_list'])
+        self.converter = TokenIDConverter(token_list)
         self.tokenizer = CharTokenizer()
         self.frontend = WavFrontend(
             cmvn_file=cmvn_file,
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
index b39e2f5..7da5afc 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
@@ -3,7 +3,7 @@
 import os.path
 from pathlib import Path
 from typing import List, Union, Tuple
-
+import json
 import copy
 import librosa
 import numpy as np
@@ -48,25 +48,24 @@
         if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
             print(".onnx is not exist, begin to export onnx")
             try:
-                from funasr.export.export_model import ModelExport
+                from funasr import AutoModel
             except:
                 raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
                       "\npip3 install -U funasr\n" \
                       "For the users in China, you could install with the command:\n" \
                       "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
-            export_model = ModelExport(
-                cache_dir=cache_dir,
-                onnx=True,
-                device="cpu",
-                quant=quantize,
-            )
-            export_model.export(model_dir)
+
+            model = AutoModel(model=cache_dir)
+            model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
 
         config_file = os.path.join(model_dir, 'config.yaml')
         cmvn_file = os.path.join(model_dir, 'am.mvn')
         config = read_yaml(config_file)
+        token_list = os.path.join(model_dir, 'tokens.json')
+        with open(token_list, 'r', encoding='utf-8') as f:
+            token_list = json.load(f)
 
-        self.converter = TokenIDConverter(config['token_list'])
+        self.converter = TokenIDConverter(token_list)
         self.tokenizer = CharTokenizer()
         self.frontend = WavFrontendOnline(
             cmvn_file=cmvn_file,
diff --git a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 86a276f..4e1014f 100644
--- a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -6,7 +6,7 @@
 from pathlib import Path
 from typing import List, Union, Tuple
 import numpy as np
-
+import json
 from .utils.utils import (ONNXRuntimeError,
                           OrtInferSession, get_logger,
                           read_yaml)
@@ -48,24 +48,23 @@
         if not os.path.exists(model_file):
             print(".onnx is not exist, begin to export onnx")
             try:
-                from funasr.export.export_model import ModelExport
+                from funasr import AutoModel
             except:
                 raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
                       "\npip3 install -U funasr\n" \
                       "For the users in China, you could install with the command:\n" \
                       "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
-            export_model = ModelExport(
-                cache_dir=cache_dir,
-                onnx=True,
-                device="cpu",
-                quant=quantize,
-            )
-            export_model.export(model_dir)
+
+            model = AutoModel(model=cache_dir)
+            model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
             
         config_file = os.path.join(model_dir, 'punc.yaml')
         config = read_yaml(config_file)
+        token_list = os.path.join(model_dir, 'tokens.json')
+        with open(token_list, 'r', encoding='utf-8') as f:
+            token_list = json.load(f)
 
-        self.converter = TokenIDConverter(config['token_list'])
+        self.converter = TokenIDConverter(token_list)
         self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
         self.batch_size = 1
         self.punc_list = config['punc_list']
diff --git a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
index 5892995..af32b1d 100644
--- a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
@@ -54,19 +54,15 @@
 		if not os.path.exists(model_file):
 			print(".onnx is not exist, begin to export onnx")
 			try:
-				from funasr.export.export_model import ModelExport
+				from funasr import AutoModel
 			except:
 				raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
 				      "\npip3 install -U funasr\n" \
 				      "For the users in China, you could install with the command:\n" \
 				      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
-			export_model = ModelExport(
-				cache_dir=cache_dir,
-				onnx=True,
-				device="cpu",
-				quant=quantize,
-			)
-			export_model.export(model_dir)
+			
+			model = AutoModel(model=cache_dir)
+			model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
 		config_file = os.path.join(model_dir, 'vad.yaml')
 		cmvn_file = os.path.join(model_dir, 'vad.mvn')
 		config = read_yaml(config_file)
@@ -222,19 +218,16 @@
 		if not os.path.exists(model_file):
 			print(".onnx is not exist, begin to export onnx")
 			try:
-				from funasr.export.export_model import ModelExport
+				from funasr import AutoModel
 			except:
 				raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
 				      "\npip3 install -U funasr\n" \
 				      "For the users in China, you could install with the command:\n" \
 				      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
-			export_model = ModelExport(
-				cache_dir=cache_dir,
-				onnx=True,
-				device="cpu",
-				quant=quantize,
-			)
-			export_model.export(model_dir)
+			
+			model = AutoModel(model=cache_dir)
+			model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
+			
 		config_file = os.path.join(model_dir, 'vad.yaml')
 		cmvn_file = os.path.join(model_dir, 'vad.mvn')
 		config = read_yaml(config_file)

--
Gitblit v1.9.1