From 9d48230c4f8f25bf88c5d6105f97370a36c9cf43 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 11 三月 2024 10:48:50 +0800
Subject: [PATCH] export onnx (#1457)

---
 funasr/models/paraformer/model.py                                       |   12 
 examples/industrial_data_pretraining/bicif_paraformer/export.sh         |   28 +
 funasr/bin/export.py                                                    |    3 
 funasr/models/ct_transformer/model.py                                   |   54 +++
 examples/industrial_data_pretraining/paraformer/export.py               |   17 
 setup.py                                                                |    1 
 examples/industrial_data_pretraining/ct_transformer/export.py           |   26 +
 examples/industrial_data_pretraining/bicif_paraformer/export.py         |   26 +
 funasr/utils/export_utils.py                                            |    2 
 examples/industrial_data_pretraining/paraformer/README_zh.md            |   14 
 funasr/models/bicif_paraformer/model.py                                 |   88 ++++
 funasr/models/sanm/attention_export.py                                  |  114 ++++++
 examples/industrial_data_pretraining/paraformer/export.sh               |   14 
 funasr/models/ct_transformer_streaming/encoder.py                       |  105 +++++
 examples/industrial_data_pretraining/ct_transformer/export.sh           |   28 +
 funasr/models/sanm/attention.py                                         |  141 +++++++
 examples/industrial_data_pretraining/ct_transformer_streaming/export.py |   26 +
 funasr/models/fsmn_vad_streaming/model.py                               |    7 
 funasr/auto/auto_model.py                                               |    4 
 funasr/models/paraformer/cif_predictor.py                               |    2 
 funasr/models/bicif_paraformer/cif_predictor.py                         |  163 +++++++++
 funasr/models/fsmn_vad_streaming/encoder.py                             |   60 ++
 examples/industrial_data_pretraining/fsmn_vad_streaming/export.py       |   16 
 funasr/models/paraformer/decoder.py                                     |    5 
 funasr/models/ct_transformer_streaming/model.py                         |   65 +++
 examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh       |   18 
 examples/industrial_data_pretraining/ct_transformer_streaming/export.sh |   28 +
 27 files changed, 1,029 insertions(+), 38 deletions(-)

diff --git a/examples/industrial_data_pretraining/bicif_paraformer/export.py b/examples/industrial_data_pretraining/bicif_paraformer/export.py
new file mode 100644
index 0000000..78e7295
--- /dev/null
+++ b/examples/industrial_data_pretraining/bicif_paraformer/export.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+from funasr import AutoModel
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+                  model_revision="v2.0.4")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
+
+
+# # method2, inference from local path
+# from funasr import AutoModel
+#
+# wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+#
+# model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+#
+# res = model.export(input=wav_file, type="onnx", quantize=False)
+# print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/export.sh b/examples/industrial_data_pretraining/bicif_paraformer/export.sh
new file mode 100644
index 0000000..bc20a90
--- /dev/null
+++ b/examples/industrial_data_pretraining/bicif_paraformer/export.sh
@@ -0,0 +1,28 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+export HYDRA_FULL_ERROR=1
+
+
+model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model_revision="v2.0.4"
+
+python -m funasr.bin.export \
+++model=${model} \
+++model_revision=${model_revision} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
+
+
+# method2, inference from local path
+model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+
+python -m funasr.bin.export \
+++model=${model} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/export.py b/examples/industrial_data_pretraining/ct_transformer/export.py
new file mode 100644
index 0000000..3321525
--- /dev/null
+++ b/examples/industrial_data_pretraining/ct_transformer/export.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+from funasr import AutoModel
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
+
+model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  model_revision="v2.0.4")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
+
+
+# method2, inference from local path
+from funasr import AutoModel
+
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/export.sh b/examples/industrial_data_pretraining/ct_transformer/export.sh
new file mode 100644
index 0000000..f7849a1
--- /dev/null
+++ b/examples/industrial_data_pretraining/ct_transformer/export.sh
@@ -0,0 +1,28 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+export HYDRA_FULL_ERROR=1
+
+
+model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+model_revision="v2.0.4"
+
+python -m funasr.bin.export \
+++model=${model} \
+++model_revision=${model_revision} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
+
+
+# method2, inference from local path
+model="/Users/zhifu/.cache/modelscope/hub/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+
+python -m funasr.bin.export \
+++model=${model} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer_streaming/export.py b/examples/industrial_data_pretraining/ct_transformer_streaming/export.py
new file mode 100644
index 0000000..4e50501
--- /dev/null
+++ b/examples/industrial_data_pretraining/ct_transformer_streaming/export.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+from funasr import AutoModel
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
+
+model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
+                  model_revision="v2.0.4")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
+
+
+# method2, inference from local path
+from funasr import AutoModel
+
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer_streaming/export.sh b/examples/industrial_data_pretraining/ct_transformer_streaming/export.sh
new file mode 100644
index 0000000..118afbb
--- /dev/null
+++ b/examples/industrial_data_pretraining/ct_transformer_streaming/export.sh
@@ -0,0 +1,28 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+export HYDRA_FULL_ERROR=1
+
+
+model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
+model_revision="v2.0.4"
+
+python -m funasr.bin.export \
+++model=${model} \
+++model_revision=${model_revision} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
+
+
+# method2, inference from local path
+model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
+
+python -m funasr.bin.export \
+++model=${model} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
index d259104..2e09523 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
@@ -3,10 +3,24 @@
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 
+
+# method1, inference from model hub
+
 from funasr import AutoModel
 wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
 
-model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4", export_model=True)
+model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
+
+# method2, inference from local path
+
+from funasr import AutoModel
+
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch")
 
 res = model.export(input=wav_file, type="onnx", quantize=False)
 print(res)
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
index 0bb4617..911a1a1 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
@@ -6,14 +6,24 @@
 export HYDRA_FULL_ERROR=1
 
 
-model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 model_revision="v2.0.4"
 
-python funasr/bin/export.py \
+python -m funasr.bin.export \
 ++model=${model} \
 ++model_revision=${model_revision} \
 ++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
 ++quantize=false \
-++device="cpu" \
-++debug=false
+++device="cpu"
+
+
+# method2, inference from local path
+model="/Users/zhifu/.cache/modelscope/hub/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+
+python -m funasr.bin.export \
+++model=${model} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/README_zh.md b/examples/industrial_data_pretraining/paraformer/README_zh.md
index f9ab616..38a4455 100644
--- a/examples/industrial_data_pretraining/paraformer/README_zh.md
+++ b/examples/industrial_data_pretraining/paraformer/README_zh.md
@@ -78,4 +78,18 @@
 
 ```bash
 tensorboard --logdir /xxxx/FunASR/examples/industrial_data_pretraining/paraformer/outputs/log/tensorboard
+```
+
+
+## 瀵煎嚭onnx
+
+```python
+from funasr import AutoModel
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+                  model_revision="v2.0.4")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
 ```
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/export.py b/examples/industrial_data_pretraining/paraformer/export.py
index 613c3a9..43b3c18 100644
--- a/examples/industrial_data_pretraining/paraformer/export.py
+++ b/examples/industrial_data_pretraining/paraformer/export.py
@@ -3,11 +3,26 @@
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 
+
+# method1, inference from model hub
+
+
 from funasr import AutoModel
 wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
 
 model = AutoModel(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
-                  model_revision="v2.0.4", export_model=True)
+                  model_revision="v2.0.4")
 
 res = model.export(input=wav_file, type="onnx", quantize=False)
 print(res)
+
+
+# method2, inference from local path
+from funasr import AutoModel
+
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/export.sh b/examples/industrial_data_pretraining/paraformer/export.sh
index 9f45a5a..c67dca8 100644
--- a/examples/industrial_data_pretraining/paraformer/export.sh
+++ b/examples/industrial_data_pretraining/paraformer/export.sh
@@ -8,11 +8,23 @@
 model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.4"
 
-python funasr/bin/export.py \
+
+python -m funasr.bin.export \
 ++model=${model} \
 ++model_revision=${model_revision} \
 ++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
 ++type="onnx" \
 ++quantize=false \
+++device="cpu"
+
+
+# method2, inference from local path
+model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+
+python -m funasr.bin.export \
+++model=${model} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
 ++device="cpu" \
 ++debug=false
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index d7b6cb9..c4bab03 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -100,9 +100,7 @@
     def __init__(self, **kwargs):
         if not kwargs.get("disable_log", True):
             tables.print()
-        if kwargs.get("export_model", False):
-            os.environ['EXPORTING_MODEL'] = 'TRUE'
-            
+
         model, kwargs = self.build_model(**kwargs)
         
         # if vad_model is not None, build vad model else None
diff --git a/funasr/bin/export.py b/funasr/bin/export.py
index 68acc17..7d47664 100644
--- a/funasr/bin/export.py
+++ b/funasr/bin/export.py
@@ -24,8 +24,9 @@
     if kwargs.get("debug", False):
         import pdb; pdb.set_trace()
 
+
+    model = AutoModel(**kwargs)
     
-    model = AutoModel(export_model=True, **kwargs)
     res = model.export(input=kwargs.get("input", None),
                        type=kwargs.get("type", "onnx"),
                        quantize=kwargs.get("quantize", False),
diff --git a/funasr/models/bicif_paraformer/cif_predictor.py b/funasr/models/bicif_paraformer/cif_predictor.py
index e7b3ba9..2cdbc16 100644
--- a/funasr/models/bicif_paraformer/cif_predictor.py
+++ b/funasr/models/bicif_paraformer/cif_predictor.py
@@ -336,3 +336,166 @@
         predictor_alignments = index_div_bool_zeros_count_tile_out
         predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
         return predictor_alignments.detach(), predictor_alignments_length.detach()
+
+@tables.register("predictor_classes", "CifPredictorV3Export")
+class CifPredictorV3Export(torch.nn.Module):
+    def __init__(self, model, **kwargs):
+        super().__init__()
+        
+        self.pad = model.pad
+        self.cif_conv1d = model.cif_conv1d
+        self.cif_output = model.cif_output
+        self.threshold = model.threshold
+        self.smooth_factor = model.smooth_factor
+        self.noise_threshold = model.noise_threshold
+        self.tail_threshold = model.tail_threshold
+        
+        self.upsample_times = model.upsample_times
+        self.upsample_cnn = model.upsample_cnn
+        self.blstm = model.blstm
+        self.cif_output2 = model.cif_output2
+        self.smooth_factor2 = model.smooth_factor2
+        self.noise_threshold2 = model.noise_threshold2
+    
+    def forward(self, hidden: torch.Tensor,
+                mask: torch.Tensor,
+                ):
+        h = hidden
+        context = h.transpose(1, 2)
+        queries = self.pad(context)
+        output = torch.relu(self.cif_conv1d(queries))
+        output = output.transpose(1, 2)
+        
+        output = self.cif_output(output)
+        alphas = torch.sigmoid(output)
+        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+        mask = mask.transpose(-1, -2).float()
+        alphas = alphas * mask
+        alphas = alphas.squeeze(-1)
+        token_num = alphas.sum(-1)
+        
+        mask = mask.squeeze(-1)
+        hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
+        acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
+        
+        return acoustic_embeds, token_num, alphas, cif_peak
+    
+    def get_upsample_timestmap(self, hidden, mask=None, token_num=None):
+        h = hidden
+        b = hidden.shape[0]
+        context = h.transpose(1, 2)
+        
+        # generate alphas2
+        _output = context
+        output2 = self.upsample_cnn(_output)
+        output2 = output2.transpose(1, 2)
+        output2, (_, _) = self.blstm(output2)
+        alphas2 = torch.sigmoid(self.cif_output2(output2))
+        alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
+        
+        mask = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
+        mask = mask.unsqueeze(-1)
+        alphas2 = alphas2 * mask
+        alphas2 = alphas2.squeeze(-1)
+        _token_num = alphas2.sum(-1)
+        alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
+        # upsampled alphas and cif_peak
+        us_alphas = alphas2
+        us_cif_peak = cif_wo_hidden_export(us_alphas, self.threshold - 1e-4)
+        return us_alphas, us_cif_peak
+    
+    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
+        b, t, d = hidden.size()
+        tail_threshold = self.tail_threshold
+        
+        zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
+        ones_t = torch.ones_like(zeros_t)
+        
+        mask_1 = torch.cat([mask, zeros_t], dim=1)
+        mask_2 = torch.cat([ones_t, mask], dim=1)
+        mask = mask_2 - mask_1
+        tail_threshold = mask * tail_threshold
+        alphas = torch.cat([alphas, zeros_t], dim=1)
+        alphas = torch.add(alphas, tail_threshold)
+        
+        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
+        hidden = torch.cat([hidden, zeros], dim=1)
+        token_num = alphas.sum(dim=-1)
+        token_num_floor = torch.floor(token_num)
+        
+        return hidden, alphas, token_num_floor
+
+
+@torch.jit.script
+def cif_export(hidden, alphas, threshold: float):
+    batch_size, len_time, hidden_size = hidden.size()
+    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
+    
+    # loop varss
+    integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
+    frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
+    # intermediate vars along time
+    list_fires = []
+    list_frames = []
+    
+    for t in range(len_time):
+        alpha = alphas[:, t]
+        distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
+        
+        integrate += alpha
+        list_fires.append(integrate)
+        
+        fire_place = integrate >= threshold
+        integrate = torch.where(fire_place,
+                                integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
+                                integrate)
+        cur = torch.where(fire_place,
+                          distribution_completion,
+                          alpha)
+        remainds = alpha - cur
+        
+        frame += cur[:, None] * hidden[:, t, :]
+        list_frames.append(frame)
+        frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
+                            remainds[:, None] * hidden[:, t, :],
+                            frame)
+    
+    fires = torch.stack(list_fires, 1)
+    frames = torch.stack(list_frames, 1)
+    
+    fire_idxs = fires >= threshold
+    frame_fires = torch.zeros_like(hidden)
+    max_label_len = frames[0, fire_idxs[0]].size(0)
+    for b in range(batch_size):
+        frame_fire = frames[b, fire_idxs[b]]
+        frame_len = frame_fire.size(0)
+        frame_fires[b, :frame_len, :] = frame_fire
+        
+        if frame_len >= max_label_len:
+            max_label_len = frame_len
+    frame_fires = frame_fires[:, :max_label_len, :]
+    return frame_fires, fires
+
+
+@torch.jit.script
+def cif_wo_hidden_export(alphas, threshold: float):
+    batch_size, len_time = alphas.size()
+    
+    # loop varss
+    integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=alphas.device)
+    # intermediate vars along time
+    list_fires = []
+    
+    for t in range(len_time):
+        alpha = alphas[:, t]
+        
+        integrate += alpha
+        list_fires.append(integrate)
+        
+        fire_place = integrate >= threshold
+        integrate = torch.where(fire_place,
+                                integrate - torch.ones([batch_size], device=alphas.device) * threshold,
+                                integrate)
+    
+    fires = torch.stack(list_fires, 1)
+    return fires
\ No newline at end of file
diff --git a/funasr/models/bicif_paraformer/model.py b/funasr/models/bicif_paraformer/model.py
index 696cd56..eb7318b 100644
--- a/funasr/models/bicif_paraformer/model.py
+++ b/funasr/models/bicif_paraformer/model.py
@@ -341,4 +341,90 @@
                     result_i = {"key": key[i], "token_int": token_int}
                 results.append(result_i)
         
-        return results, meta_data
\ No newline at end of file
+        return results, meta_data
+
+    def export(
+        self,
+        max_seq_len=512,
+        **kwargs,
+    ):
+        is_onnx = kwargs.get("type", "onnx") == "onnx"
+        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
+    
+        predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
+        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
+    
+        decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
+        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
+    
+        from funasr.utils.torch_function import MakePadMask
+        from funasr.utils.torch_function import sequence_mask
+    
+        if is_onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+    
+        self.forward = self._export_forward
+    
+        return self
+
+    def _export_forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ):
+        # a. To device
+        batch = {"speech": speech, "speech_lengths": speech_lengths}
+        # batch = to_device(batch, device=self.device)
+    
+        enc, enc_len = self.encoder(**batch)
+        mask = self.make_pad_mask(enc_len)[:, None, :]
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
+        pre_token_length = pre_token_length.round().type(torch.int32)
+    
+        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+    
+        # get predicted timestamps
+        us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
+    
+        return decoder_out, pre_token_length, us_alphas, us_cif_peak
+
+    def export_dummy_inputs(self):
+        speech = torch.randn(2, 30, 560)
+        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+        return (speech, speech_lengths)
+
+    def export_input_names(self):
+        return ['speech', 'speech_lengths']
+
+    def export_output_names(self):
+        return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
+
+    def export_dynamic_axes(self):
+        return {
+            'speech': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+            'speech_lengths': {
+                0: 'batch_size',
+            },
+            'logits': {
+                0: 'batch_size',
+                1: 'logits_length'
+            },
+            'us_alphas': {
+                0: 'batch_size',
+                1: 'alphas_length'
+            },
+            'us_cif_peak': {
+                0: 'batch_size',
+                1: 'alphas_length'
+            },
+        }
+
+    def export_name(self, ):
+        return "model.onnx"
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index 45f5746..31b8c27 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -365,3 +365,57 @@
         results.append(result_i)
         return results, meta_data
 
+    def export(
+        self,
+        **kwargs,
+    ):
+
+        is_onnx = kwargs.get("type", "onnx") == "onnx"
+        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
+        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
+
+        self.forward = self._export_forward
+        
+        return self
+
+    def _export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
+        """Compute loss value from buffer sequences.
+
+        Args:
+            input (torch.Tensor): Input ids. (batch, len)
+            hidden (torch.Tensor): Target ids. (batch, len)
+
+        """
+        x = self.embed(inputs)
+        h, _ = self.encoder(x, text_lengths)
+        y = self.decoder(h)
+        return y
+
+    def export_dummy_inputs(self):
+        length = 120
+        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
+        text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
+        return (text_indexes, text_lengths)
+
+    def export_input_names(self):
+        return ['inputs', 'text_lengths']
+
+    def export_output_names(self):
+        return ['logits']
+
+    def export_dynamic_axes(self):
+        return {
+            'inputs': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+            'text_lengths': {
+                0: 'batch_size',
+            },
+            'logits': {
+                0: 'batch_size',
+                1: 'logits_length'
+            },
+        }
+    def export_name(self):
+        return "model.onnx"
\ No newline at end of file
diff --git a/funasr/models/ct_transformer_streaming/encoder.py b/funasr/models/ct_transformer_streaming/encoder.py
index 95e2a4b..badf5f6 100644
--- a/funasr/models/ct_transformer_streaming/encoder.py
+++ b/funasr/models/ct_transformer_streaming/encoder.py
@@ -371,3 +371,108 @@
         if len(intermediate_outs) > 0:
             return (xs_pad, intermediate_outs), olens, None
         return xs_pad, olens, None
+
+
+class EncoderLayerSANMExport(torch.nn.Module):
+    def __init__(
+        self,
+        model,
+    ):
+        """Construct an EncoderLayer object."""
+        super().__init__()
+        self.self_attn = model.self_attn
+        self.feed_forward = model.feed_forward
+        self.norm1 = model.norm1
+        self.norm2 = model.norm2
+        self.in_size = model.in_size
+        self.size = model.size
+
+    def forward(self, x, mask):
+
+        residual = x
+        x = self.norm1(x)
+        x = self.self_attn(x, mask)
+        if self.in_size == self.size:
+            x = x + residual
+        residual = x
+        x = self.norm2(x)
+        x = self.feed_forward(x)
+        x = x + residual
+
+        return x, mask
+
+@tables.register("encoder_classes", "SANMVadEncoderExport")
+class SANMVadEncoderExport(torch.nn.Module):
+    def __init__(
+        self,
+        model,
+        max_seq_len=512,
+        feats_dim=560,
+        model_name='encoder',
+        onnx: bool = True,
+    ):
+        super().__init__()
+        self.embed = model.embed
+        self.model = model
+        self._output_size = model._output_size
+        
+        from funasr.utils.torch_function import MakePadMask
+        from funasr.utils.torch_function import sequence_mask
+        
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+        
+        from funasr.models.sanm.attention import MultiHeadedAttentionSANMExport
+        
+        if hasattr(model, 'encoders0'):
+            for i, d in enumerate(self.model.encoders0):
+                if isinstance(d.self_attn, MultiHeadedAttentionSANMwithMask):
+                    d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
+                self.model.encoders0[i] = EncoderLayerSANMExport(d)
+        
+        for i, d in enumerate(self.model.encoders):
+            if isinstance(d.self_attn, MultiHeadedAttentionSANMwithMask):
+                d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
+            self.model.encoders[i] = EncoderLayerSANMExport(d)
+        
+        
+    def prepare_mask(self, mask, sub_masks):
+        mask_3d_btd = mask[:, :, None]
+        mask_4d_bhlt = (1 - sub_masks) * -10000.0
+        
+        return mask_3d_btd, mask_4d_bhlt
+    
+    def forward(self,
+                speech: torch.Tensor,
+                speech_lengths: torch.Tensor,
+                vad_masks: torch.Tensor,
+                sub_masks: torch.Tensor,
+                ):
+        speech = speech * self._output_size ** 0.5
+        mask = self.make_pad_mask(speech_lengths)
+        vad_masks = self.prepare_mask(mask, vad_masks)
+        mask = self.prepare_mask(mask, sub_masks)
+        
+        if self.embed is None:
+            xs_pad = speech
+        else:
+            xs_pad = self.embed(speech)
+        
+        encoder_outs = self.model.encoders0(xs_pad, mask)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        
+        # encoder_outs = self.model.encoders(xs_pad, mask)
+        for layer_idx, encoder_layer in enumerate(self.model.encoders):
+            if layer_idx == len(self.model.encoders) - 1:
+                mask = vad_masks
+            encoder_outs = encoder_layer(xs_pad, mask)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        
+        xs_pad = self.model.after_norm(xs_pad)
+        
+        return xs_pad, speech_lengths
+    
+    def get_output_size(self):
+        return self.model.encoders[0].size
\ No newline at end of file
diff --git a/funasr/models/ct_transformer_streaming/model.py b/funasr/models/ct_transformer_streaming/model.py
index 217767a..a9b2efb 100644
--- a/funasr/models/ct_transformer_streaming/model.py
+++ b/funasr/models/ct_transformer_streaming/model.py
@@ -173,3 +173,68 @@
     
         return results, meta_data
 
+    def export(
+        self,
+        **kwargs,
+    ):
+    
+        is_onnx = kwargs.get("type", "onnx") == "onnx"
+        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
+    
+        self.forward = self._export_forward
+    
+        return self
+
+    def _export_forward(self, inputs: torch.Tensor,
+                text_lengths: torch.Tensor,
+                vad_indexes: torch.Tensor,
+                sub_masks: torch.Tensor,
+                ):
+        """Compute loss value from buffer sequences.
+
+        Args:
+            input (torch.Tensor): Input ids. (batch, len)
+            hidden (torch.Tensor): Target ids. (batch, len)
+
+        """
+        x = self.embed(inputs)
+        # mask = self._target_mask(input)
+        h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
+        y = self.decoder(h)
+        return y
+
+    def export_dummy_inputs(self):
+        length = 120
+        text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
+        text_lengths = torch.tensor([length], dtype=torch.int32)
+        vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
+        sub_masks = torch.ones(length, length, dtype=torch.float32)
+        sub_masks = torch.tril(sub_masks).type(torch.float32)
+        return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
+
+    def export_input_names(self):
+        return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
+
+    def export_output_names(self):
+        return ['logits']
+
+    def export_dynamic_axes(self):
+        return {
+            'inputs': {
+                1: 'feats_length'
+            },
+            'vad_masks': {
+                2: 'feats_length1',
+                3: 'feats_length2'
+            },
+            'sub_masks': {
+                2: 'feats_length1',
+                3: 'feats_length2'
+            },
+            'logits': {
+                1: 'logits_length'
+            },
+        }
+    def export_name(self):
+        return "model.onnx"
diff --git a/funasr/models/fsmn_vad_streaming/encoder.py b/funasr/models/fsmn_vad_streaming/encoder.py
index e7c0e8b..bc51a6f 100755
--- a/funasr/models/fsmn_vad_streaming/encoder.py
+++ b/funasr/models/fsmn_vad_streaming/encoder.py
@@ -194,7 +194,7 @@
             output_affine_dim: int,
             output_dim: int
     ):
-        super(FSMN, self).__init__()
+        super().__init__()
 
         self.input_dim = input_dim
         self.input_affine_dim = input_affine_dim
@@ -213,12 +213,6 @@
         self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
         self.softmax = nn.Softmax(dim=-1)
         
-        # export onnx or torchscripts
-        if "EXPORTING_MODEL" in os.environ and os.environ['EXPORTING_MODEL'] == 'TRUE':
-            for i, d in enumerate(self.fsmn):
-                if isinstance(d, BasicBlock):
-                    self.fsmn[i] = BasicBlock_export(d)
-
     def fuse_modules(self):
         pass
 
@@ -244,10 +238,49 @@
 
         return x7
 
-    def export_forward(
-            self,
-            input: torch.Tensor,
-            *args,
+
+@tables.register("encoder_classes", "FSMNExport")
+class FSMNExport(nn.Module):
+    def __init__(
+        self, model, **kwargs,
+    ):
+        super().__init__()
+        
+        # self.input_dim = input_dim
+        # self.input_affine_dim = input_affine_dim
+        # self.fsmn_layers = fsmn_layers
+        # self.linear_dim = linear_dim
+        # self.proj_dim = proj_dim
+        # self.output_affine_dim = output_affine_dim
+        # self.output_dim = output_dim
+        #
+        # self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
+        # self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
+        # self.relu = RectifiedLinear(linear_dim, linear_dim)
+        # self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
+        #                         range(fsmn_layers)])
+        # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
+        # self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
+        # self.softmax = nn.Softmax(dim=-1)
+        self.in_linear1 = model.in_linear1
+        self.in_linear2 = model.in_linear2
+        self.relu = model.relu
+        # self.fsmn = model.fsmn
+        self.out_linear1 = model.out_linear1
+        self.out_linear2 = model.out_linear2
+        self.softmax = model.softmax
+        self.fsmn = model.fsmn
+        for i, d in enumerate(model.fsmn):
+            if isinstance(d, BasicBlock):
+                self.fsmn[i] = BasicBlock_export(d)
+    
+    def fuse_modules(self):
+        pass
+    
+    def forward(
+        self,
+        input: torch.Tensor,
+        *args,
     ):
         """
         Args:
@@ -255,7 +288,7 @@
             in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
             {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
         """
-
+        
         x = self.in_linear1(input)
         x = self.in_linear2(x)
         x = self.relu(x)
@@ -268,9 +301,10 @@
         x = self.out_linear1(x)
         x = self.out_linear2(x)
         x = self.softmax(x)
-
+        
         return x, out_caches
 
+
 '''
 one deep fsmn layer
 dimproj:                projection dimension, input and output dimension of memory blocks
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 9649ee8..c3063b0 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -644,12 +644,17 @@
 		return results, meta_data
 	
 	def export(self, **kwargs):
+		is_onnx = kwargs.get("type", "onnx") == "onnx"
+		encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+		self.encoder = encoder_class(self.encoder, onnx=is_onnx)
 		self.forward = self._export_forward
 		
 		return self
 		
 	def _export_forward(self, feats: torch.Tensor, *args, **kwargs):
-		scores, out_caches = self.encoder.export_forward(feats, *args)
+		
+		scores, out_caches = self.encoder(feats, *args)
+		
 		return scores, out_caches
 	
 	def export_dummy_inputs(self, data_in=None, frame=30):
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index ddcfb5a..d538e21 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -376,7 +376,7 @@
         return predictor_alignments.detach(), predictor_alignments_length.detach()
 
 @tables.register("predictor_classes", "CifPredictorV2Export")
-class CifPredictorV2(torch.nn.Module):
+class CifPredictorV2Export(torch.nn.Module):
     def __init__(self, model, **kwargs):
         super().__init__()
         
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index ce018f4..572a34a 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -635,8 +635,9 @@
         else:
             self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
         
-        from funasr.models.sanm.multihead_att import MultiHeadedAttentionSANMDecoderExport
-        from funasr.models.sanm.multihead_att import MultiHeadedAttentionCrossAttExport
+        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
+        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
+        
         
         for i, d in enumerate(self.model.decoders):
             if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 8739ed5..586d72d 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -554,21 +554,23 @@
         max_seq_len=512,
         **kwargs,
     ):
-        onnx = kwargs.get("onnx", True)
+        
+        is_onnx = kwargs.get("type", "onnx") == "onnx"
         encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
-        self.encoder = encoder_class(self.encoder, onnx=onnx)
+        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
         
         predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
-        self.predictor = predictor_class(self.predictor, onnx=onnx)
+        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
 
 
         decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
-        self.decoder = decoder_class(self.decoder, onnx=onnx)
+        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
         
         from funasr.utils.torch_function import MakePadMask
         from funasr.utils.torch_function import sequence_mask
         
-        if onnx:
+        
+        if is_onnx:
             self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
         else:
             self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index 09a1f07..c3a2f94 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -17,6 +17,24 @@
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 import funasr.models.lora.layers as lora
 
+
+def preprocess_for_attn(x, mask, cache, pad_fn, kernel_size):
+    x = x * mask
+    x = x.transpose(1, 2)
+    if cache is None:
+        x = pad_fn(x)
+    else:
+        x = torch.cat((cache, x), dim=2)
+        cache = x[:, :, -(kernel_size-1):]
+    return x, cache
+
+
+torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
+if torch_version >= (1, 8):
+    import torch.fx
+    torch.fx.wrap('preprocess_for_attn')
+
+
 class MultiHeadedAttention(nn.Module):
     """Multi-Head Attention layer.
 
@@ -362,6 +380,65 @@
         return self.linear_out(context_layer)  # (batch, time1, d_model)
 
 
+class MultiHeadedAttentionSANMExport(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.d_k = model.d_k
+        self.h = model.h
+        self.linear_out = model.linear_out
+        self.linear_q_k_v = model.linear_q_k_v
+        self.fsmn_block = model.fsmn_block
+        self.pad_fn = model.pad_fn
+
+        self.attn = None
+        self.all_head_size = self.h * self.d_k
+
+    def forward(self, x, mask):
+        mask_3d_btd, mask_4d_bhlt = mask
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
+        q_h = q_h * self.d_k**(-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
+        return att_outs + fsmn_memory
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward_qkv(self, x):
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        q_h = self.transpose_for_scores(q)
+        k_h = self.transpose_for_scores(k)
+        v_h = self.transpose_for_scores(v)
+        return q_h, k_h, v_h, v
+
+    def forward_fsmn(self, inputs, mask):
+        # b, t, d = inputs.size()
+        # mask = torch.reshape(mask, (b, -1, 1))
+        inputs = inputs * mask
+        x = inputs.transpose(1, 2)
+        x = self.pad_fn(x)
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+        x = x + inputs
+        x = x * mask
+        return x
+
+    def forward_attention(self, value, scores, mask):
+        scores = scores + mask
+
+        self.attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+        return self.linear_out(context_layer)  # (batch, time1, d_model)
+
+
 
 class MultiHeadedAttentionSANMDecoder(nn.Module):
     """Multi-Head Attention layer.
@@ -375,7 +452,7 @@
 
     def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
         """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionSANMDecoder, self).__init__()
+        super().__init__()
 
         self.dropout = nn.Dropout(p=dropout_rate)
 
@@ -440,6 +517,24 @@
             x = x * mask
         return x, cache
 
+class MultiHeadedAttentionSANMDecoderExport(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.fsmn_block = model.fsmn_block
+        self.pad_fn = model.pad_fn
+        self.kernel_size = model.kernel_size
+        self.attn = None
+
+    def forward(self, inputs, mask, cache=None):
+        x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn, self.kernel_size)
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+
+        x = x + inputs
+        x = x * mask
+        return x, cache
+
+
 class MultiHeadedAttentionCrossAtt(nn.Module):
     """Multi-Head Attention layer.
 
@@ -452,7 +547,7 @@
 
     def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
         """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionCrossAtt, self).__init__()
+        super().__init__()
         assert n_feat % n_head == 0
         # We assume d_v always equals d_k
         self.d_k = n_feat // n_head
@@ -591,6 +686,48 @@
         scores = torch.matmul(q_h, k_h.transpose(-2, -1))
         return self.forward_attention(v_h, scores, None), cache
 
+class MultiHeadedAttentionCrossAttExport(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.d_k = model.d_k
+        self.h = model.h
+        self.linear_q = model.linear_q
+        self.linear_k_v = model.linear_k_v
+        self.linear_out = model.linear_out
+        self.attn = None
+        self.all_head_size = self.h * self.d_k
+
+    def forward(self, x, memory, memory_mask):
+        q, k, v = self.forward_qkv(x, memory)
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, memory_mask)
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward_qkv(self, x, memory):
+        q = self.linear_q(x)
+
+        k_v = self.linear_k_v(memory)
+        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
+        q = self.transpose_for_scores(q)
+        k = self.transpose_for_scores(k)
+        v = self.transpose_for_scores(v)
+        return q, k, v
+
+    def forward_attention(self, value, scores, mask):
+        scores = scores + mask
+
+        self.attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+        return self.linear_out(context_layer)  # (batch, time1, d_model)
+
 
 class MultiHeadSelfAttention(nn.Module):
     """Multi-Head Attention layer.
diff --git a/funasr/models/sanm/attention_export.py b/funasr/models/sanm/attention_export.py
new file mode 100644
index 0000000..435dd1e
--- /dev/null
+++ b/funasr/models/sanm/attention_export.py
@@ -0,0 +1,114 @@
+import os
+import math
+
+import torch
+import torch.nn as nn
+
+
+
+
+
+
+
+
+
+
+class OnnxMultiHeadedAttention(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.d_k = model.d_k
+        self.h = model.h
+        self.linear_q = model.linear_q
+        self.linear_k = model.linear_k
+        self.linear_v = model.linear_v
+        self.linear_out = model.linear_out
+        self.attn = None
+        self.all_head_size = self.h * self.d_k
+    
+    def forward(self, query, key, value, mask):
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, mask)
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward_qkv(self, query, key, value):
+        q = self.linear_q(query)
+        k = self.linear_k(key)
+        v = self.linear_v(value)
+        q = self.transpose_for_scores(q)
+        k = self.transpose_for_scores(k)
+        v = self.transpose_for_scores(v)
+        return q, k, v
+    
+    def forward_attention(self, value, scores, mask):
+        scores = scores + mask
+
+        self.attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
+        
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+        return self.linear_out(context_layer)  # (batch, time1, d_model)
+
+
+class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention):
+    def __init__(self, model):
+        super().__init__(model)
+        self.linear_pos = model.linear_pos
+        self.pos_bias_u = model.pos_bias_u
+        self.pos_bias_v = model.pos_bias_v
+    
+    def forward(self, query, key, value, pos_emb, mask):
+        q, k, v = self.forward_qkv(query, key, value)
+        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+
+        p = self.transpose_for_scores(self.linear_pos(pos_emb)) # (batch, head, time1, d_k)
+
+        # (batch, head, time1, d_k)
+        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+        # (batch, head, time1, d_k)
+        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+        # compute matrix b and matrix d
+        # (batch, head, time1, time1)
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        matrix_bd = self.rel_shift(matrix_bd)
+
+        scores = (matrix_ac + matrix_bd) / math.sqrt(
+            self.d_k
+        )  # (batch, head, time1, time2)
+
+        return self.forward_attention(v, scores, mask)
+
+    def rel_shift(self, x):
+        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+        x_padded = torch.cat([zero_pad, x], dim=-1)
+
+        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+        x = x_padded[:, :, 1:].view_as(x)[
+            :, :, :, : x.size(-1) // 2 + 1
+        ]  # only keep the positions from 0 to time2
+        return x
+
+    def forward_attention(self, value, scores, mask):
+        scores = scores + mask
+
+        self.attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
+        
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+        return self.linear_out(context_layer)  # (batch, time1, d_model)
+        
diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
index 640be05..f563a9b 100644
--- a/funasr/utils/export_utils.py
+++ b/funasr/utils/export_utils.py
@@ -3,7 +3,6 @@
 
 def export_onnx(model,
                 data_in=None,
-				type: str = "onnx",
 				quantize: bool = False,
 				fallback_num: int = 5,
 				calib_num: int = 100,
@@ -19,7 +18,6 @@
 		m.eval()
 		_onnx(m,
 		      data_in=data_in,
-		      type=type,
 		      quantize=quantize,
 		      fallback_num=fallback_num,
 		      calib_num=calib_num,
diff --git a/setup.py b/setup.py
index e3d1c2e..b41148c 100644
--- a/setup.py
+++ b/setup.py
@@ -140,5 +140,6 @@
     ],
     entry_points={"console_scripts": [
         "funasr = funasr.bin.inference:main_hydra",
+        "funasr-export = funasr.bin.export:main_hydra",
     ]},
 )

--
Gitblit v1.9.1