zhifu gao
2024-03-11 9d48230c4f8f25bf88c5d6105f97370a36c9cf43
export onnx (#1457)

* qwenaudio qwenaudiochat

* qwenaudio qwenaudiochat

* whisper

* whisper

* llm

* llm

* llm

* llm

* llm

* llm

* llm

* llm

* export onnx

* export onnx
20个文件已修改
7个文件已添加
1053 ■■■■■ 已修改文件
examples/industrial_data_pretraining/bicif_paraformer/export.py 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/bicif_paraformer/export.sh 28 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/ct_transformer/export.py 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/ct_transformer/export.sh 28 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/ct_transformer_streaming/export.py 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/ct_transformer_streaming/export.sh 28 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/fsmn_vad_streaming/export.py 16 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh 18 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/paraformer/README_zh.md 14 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/paraformer/export.py 17 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/paraformer/export.sh 14 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/auto/auto_model.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/export.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/bicif_paraformer/cif_predictor.py 163 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/bicif_paraformer/model.py 86 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/ct_transformer/model.py 54 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/ct_transformer_streaming/encoder.py 105 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/ct_transformer_streaming/model.py 65 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/fsmn_vad_streaming/encoder.py 50 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/fsmn_vad_streaming/model.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/cif_predictor.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/decoder.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/model.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sanm/attention.py 141 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sanm/attention_export.py 114 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/export_utils.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
setup.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/bicif_paraformer/export.py
New file
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# method1, export a model that is fetched from the model hub
from funasr import AutoModel

# Remote sample recording that is handed to export() as its `input` argument.
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
model = AutoModel(
    model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    model_revision="v2.0.4",
)
# Request an ONNX export without quantization and print whatever export() returns.
export_options = {"input": wav_file, "type": "onnx", "quantize": False}
res = model.export(**export_options)
print(res)
# # method2, inference from local path
# from funasr import AutoModel
#
# wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
#
# model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
#
# res = model.export(input=wav_file, type="onnx", quantize=False)
# print(res)
examples/industrial_data_pretraining/bicif_paraformer/export.sh
New file
@@ -0,0 +1,28 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# method1, export a model that is fetched from the model hub
export HYDRA_FULL_ERROR=1
# NOTE(review): the sibling export.py exports the "...-vad-punc_asr_nat-..." model;
# confirm that this plain paraformer id is the one intended for the bicif example.
model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model_revision="v2.0.4"
# Expansions are quoted so a value containing spaces or glob characters stays one argument.
python -m funasr.bin.export \
++model="${model}" \
++model_revision="${model_revision}" \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++type="onnx" \
++quantize=false \
++device="cpu"
# method2, export a model stored at a local path
# The directory below is machine-specific; replace it with your own model directory.
model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
python -m funasr.bin.export \
++model="${model}" \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++type="onnx" \
++quantize=false \
++device="cpu"
examples/industrial_data_pretraining/ct_transformer/export.py
New file
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# method1, export a model that is fetched from the model hub
from funasr import AutoModel

# NOTE: despite its name, `wav_file` points at a text sample; this example exports a
# punctuation model, so the same text sample is used for both methods below.
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  model_revision="v2.0.4")
res = model.export(input=wav_file, type="onnx", quantize=False)
print(res)

# method2, export a model stored at a local path
# The directory below is machine-specific; replace it with your own model directory.
# (Previously this section re-imported AutoModel and switched the input to an audio
# file copied from the ASR examples; both were copy-paste leftovers.)
model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch")
res = model.export(input=wav_file, type="onnx", quantize=False)
print(res)
examples/industrial_data_pretraining/ct_transformer/export.sh
New file
@@ -0,0 +1,28 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# method1, export a model that is fetched from the model hub
export HYDRA_FULL_ERROR=1
model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
model_revision="v2.0.4"
# NOTE(review): the input below is an audio sample although the model is a punctuation
# model; the sibling export.py passes test_text/punc_example.txt -- confirm which is intended.
# Expansions are quoted so a value containing spaces or glob characters stays one argument.
python -m funasr.bin.export \
++model="${model}" \
++model_revision="${model_revision}" \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++type="onnx" \
++quantize=false \
++device="cpu"
# method2, export a model stored at a local path
# The directory below is machine-specific; replace it with your own model directory.
model="/Users/zhifu/.cache/modelscope/hub/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
python -m funasr.bin.export \
++model="${model}" \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++type="onnx" \
++quantize=false \
++device="cpu"
examples/industrial_data_pretraining/ct_transformer_streaming/export.py
New file
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# method1, export a model that is fetched from the model hub
from funasr import AutoModel

# NOTE: despite its name, `wav_file` points at a text sample; this example exports a
# punctuation model, so the same text sample is used for both methods below.
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
                  model_revision="v2.0.4")
res = model.export(input=wav_file, type="onnx", quantize=False)
print(res)

# method2, export a model stored at a local path
# The directory below is machine-specific; replace it with your own model directory.
# (Previously this section re-imported AutoModel and switched the input to an audio
# file copied from the ASR examples; both were copy-paste leftovers.)
model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727")
res = model.export(input=wav_file, type="onnx", quantize=False)
print(res)
examples/industrial_data_pretraining/ct_transformer_streaming/export.sh
New file
@@ -0,0 +1,28 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# method1, export a model that is fetched from the model hub
export HYDRA_FULL_ERROR=1
model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
model_revision="v2.0.4"
# NOTE(review): the input below is an audio sample although the model is a punctuation
# model; the sibling export.py passes test_text/punc_example.txt -- confirm which is intended.
# Expansions are quoted so a value containing spaces or glob characters stays one argument.
python -m funasr.bin.export \
++model="${model}" \
++model_revision="${model_revision}" \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++type="onnx" \
++quantize=false \
++device="cpu"
# method2, export a model stored at a local path
# The directory below is machine-specific; replace it with your own model directory.
model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
python -m funasr.bin.export \
++model="${model}" \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++type="onnx" \
++quantize=false \
++device="cpu"
examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
@@ -3,10 +3,24 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# method1, inference from model hub
from funasr import AutoModel
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4", export_model=True)
model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4")
res = model.export(input=wav_file, type="onnx", quantize=False)
print(res)
# method2, inference from local path
from funasr import AutoModel
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch")
res = model.export(input=wav_file, type="onnx", quantize=False)
print(res)
examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
@@ -6,14 +6,24 @@
export HYDRA_FULL_ERROR=1
model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
model_revision="v2.0.4"
python funasr/bin/export.py \
python -m funasr.bin.export \
++model=${model} \
++model_revision=${model_revision} \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++type="onnx" \
++quantize=false \
++device="cpu" \
++debug=false
++device="cpu"
# method2, inference from local path
model="/Users/zhifu/.cache/modelscope/hub/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
python -m funasr.bin.export \
++model=${model} \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++type="onnx" \
++quantize=false \
++device="cpu"
examples/industrial_data_pretraining/paraformer/README_zh.md
@@ -79,3 +79,17 @@
```bash
tensorboard --logdir /xxxx/FunASR/examples/industrial_data_pretraining/paraformer/outputs/log/tensorboard
```
## 导出onnx
```python
from funasr import AutoModel
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
model = AutoModel(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  model_revision="v2.0.4")
res = model.export(input=wav_file, type="onnx", quantize=False)
print(res)
```
examples/industrial_data_pretraining/paraformer/export.py
@@ -3,11 +3,26 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# method1, inference from model hub
from funasr import AutoModel
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
model = AutoModel(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  model_revision="v2.0.4", export_model=True)
                  model_revision="v2.0.4")
res = model.export(input=wav_file, type="onnx", quantize=False)
print(res)
# method2, inference from local path
from funasr import AutoModel
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
res = model.export(input=wav_file, type="onnx", quantize=False)
print(res)
examples/industrial_data_pretraining/paraformer/export.sh
@@ -8,11 +8,23 @@
model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model_revision="v2.0.4"
python funasr/bin/export.py \
python -m funasr.bin.export \
++model=${model} \
++model_revision=${model_revision} \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++type="onnx" \
++quantize=false \
++device="cpu"
# method2, inference from local path
model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
python -m funasr.bin.export \
++model=${model} \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++type="onnx" \
++quantize=false \
++device="cpu" \
++debug=false
funasr/auto/auto_model.py
@@ -100,8 +100,6 @@
    def __init__(self, **kwargs):
        if not kwargs.get("disable_log", True):
            tables.print()
        if kwargs.get("export_model", False):
            os.environ['EXPORTING_MODEL'] = 'TRUE'
            
        model, kwargs = self.build_model(**kwargs)
        
funasr/bin/export.py
@@ -25,7 +25,8 @@
        import pdb; pdb.set_trace()
    
    model = AutoModel(export_model=True, **kwargs)
    model = AutoModel(**kwargs)
    res = model.export(input=kwargs.get("input", None),
                       type=kwargs.get("type", "onnx"),
                       quantize=kwargs.get("quantize", False),
funasr/models/bicif_paraformer/cif_predictor.py
@@ -336,3 +336,166 @@
        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()
@tables.register("predictor_classes", "CifPredictorV3Export")
class CifPredictorV3Export(torch.nn.Module):
    """Export-oriented wrapper that re-uses the layers of an existing CIF predictor.

    The wrapped ``model`` supplies every sub-module and hyper-parameter; this class
    only provides forward logic built on ``cif_export`` / ``cif_wo_hidden_export``.
    """

    def __init__(self, model, **kwargs):
        """Copy the sub-modules and scalar settings from ``model``.

        Args:
            model: predictor instance whose attributes are re-used as-is.
            **kwargs: accepted for call-site compatibility; not read here.
        """
        super().__init__()
        self.pad = model.pad
        self.cif_conv1d = model.cif_conv1d
        self.cif_output = model.cif_output
        self.threshold = model.threshold
        self.smooth_factor = model.smooth_factor
        self.noise_threshold = model.noise_threshold
        self.tail_threshold = model.tail_threshold
        self.upsample_times = model.upsample_times
        self.upsample_cnn = model.upsample_cnn
        self.blstm = model.blstm
        self.cif_output2 = model.cif_output2
        self.smooth_factor2 = model.smooth_factor2
        self.noise_threshold2 = model.noise_threshold2

    def forward(self, hidden: torch.Tensor,
                mask: torch.Tensor,
                ):
        """Compute CIF weights for ``hidden`` and integrate them into acoustic embeddings.

        Args:
            hidden: encoder output; ``tail_process_fn`` unpacks it as (b, t, d).
            mask: mask whose last two axes are transposed before it is multiplied
                onto the weights -- presumably (b, 1, t); confirm against the caller.

        Returns:
            Tuple ``(acoustic_embeds, token_num, alphas, cif_peak)`` where
            ``token_num`` is the floored sum of the tail-processed weights.
        """
        context = hidden.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))
        output = output.transpose(1, 2)
        output = self.cif_output(output)
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        mask = mask.transpose(-1, -2).float()
        alphas = alphas * mask
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        # The token count is taken from tail_process_fn (after the tail weight is
        # added); the pre-tail sum that used to be computed here was never read.
        hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
        acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
        return acoustic_embeds, token_num, alphas, cif_peak

    def get_upsample_timestmap(self, hidden, mask=None, token_num=None):
        """Compute upsampled CIF weights and their running integration values.

        The method name (including its spelling) is kept because callers use it.
        ``mask`` and ``token_num`` default to ``None`` but both are used
        unconditionally, so callers must supply them.

        Returns:
            Tuple ``(us_alphas, us_cif_peak)``.
        """
        context = hidden.transpose(1, 2)
        # generate alphas2
        output2 = self.upsample_cnn(context)
        output2 = output2.transpose(1, 2)
        output2, (_, _) = self.blstm(output2)
        alphas2 = torch.sigmoid(self.cif_output2(output2))
        alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
        # Stretch the mask by the upsampling factor so it lines up with alphas2.
        mask = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
        mask = mask.unsqueeze(-1)
        alphas2 = alphas2 * mask
        alphas2 = alphas2.squeeze(-1)
        _token_num = alphas2.sum(-1)
        # Rescale each row so that its weights sum to the supplied token_num.
        alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
        # upsampled alphas and cif_peak
        us_alphas = alphas2
        us_cif_peak = cif_wo_hidden_export(us_alphas, self.threshold - 1e-4)
        return us_alphas, us_cif_peak

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        """Append one extra frame and add ``tail_threshold`` at the first masked-out slot.

        Args:
            hidden: tensor unpacked as (b, t, d).
            alphas: weights with one value per frame, (b, t).
            token_num: unused; kept for signature compatibility.
            mask: (b, t) mask; required despite the ``None`` default.

        Returns:
            Tuple ``(hidden, alphas, token_num_floor)`` where ``hidden`` and
            ``alphas`` are one frame longer than the inputs.
        """
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
        ones_t = torch.ones_like(zeros_t)
        # Shifting the mask by one step and subtracting leaves a non-zero entry
        # where the mask value changes between neighbouring positions.
        mask_1 = torch.cat([mask, zeros_t], dim=1)
        mask_2 = torch.cat([ones_t, mask], dim=1)
        mask = mask_2 - mask_1
        tail_threshold = mask * tail_threshold
        alphas = torch.cat([alphas, zeros_t], dim=1)
        alphas = torch.add(alphas, tail_threshold)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)
        return hidden, alphas, token_num_floor
@torch.jit.script
def cif_export(hidden, alphas, threshold: float):
    """Integrate ``alphas`` over time and emit one weighted frame per firing.

    ``hidden`` is unpacked as (batch_size, len_time, hidden_size) and ``alphas``
    is indexed as ``alphas[:, t]``. Returns ``(frame_fires, fires)``:
    ``frame_fires`` holds, per batch row, the accumulated frames at the time
    steps whose integration value reached ``threshold`` (zero padded, then cut
    to the longest row); ``fires`` holds the integration value at every step,
    recorded before a firing is subtracted.
    """
    batch_size, len_time, hidden_size = hidden.size()
    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
    # loop vars
    integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
    frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
    # intermediate vars along time
    list_fires = []
    list_frames = []
    for t in range(len_time):
        alpha = alphas[:, t]
        # Weight still missing to reach 1.0, measured before this step's alpha is added.
        distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
        integrate += alpha
        # The appended tensor is not modified later: `integrate` is rebound to a
        # new tensor by torch.where below before the next in-place add.
        list_fires.append(integrate)
        fire_place = integrate >= threshold
        integrate = torch.where(fire_place,
                                integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
                                integrate)
        # On a firing only the completing share of alpha goes into the current frame.
        cur = torch.where(fire_place,
                          distribution_completion,
                          alpha)
        remainds = alpha - cur
        frame += cur[:, None] * hidden[:, t, :]
        list_frames.append(frame)
        # After a firing the next frame starts from the leftover share of this step.
        frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
                            remainds[:, None] * hidden[:, t, :],
                            frame)
    fires = torch.stack(list_fires, 1)
    frames = torch.stack(list_frames, 1)
    fire_idxs = fires >= threshold
    frame_fires = torch.zeros_like(hidden)
    max_label_len = frames[0, fire_idxs[0]].size(0)
    # Left-align the fired frames of every batch row and track the longest row.
    for b in range(batch_size):
        frame_fire = frames[b, fire_idxs[b]]
        frame_len = frame_fire.size(0)
        frame_fires[b, :frame_len, :] = frame_fire
        if frame_len >= max_label_len:
            max_label_len = frame_len
    frame_fires = frame_fires[:, :max_label_len, :]
    return frame_fires, fires
@torch.jit.script
def cif_wo_hidden_export(alphas, threshold: float):
    """Run the integrate-and-fire accumulation on ``alphas`` alone.

    ``alphas`` is unpacked as (batch_size, len_time). Returns ``fires``, the
    integration value at every time step, recorded before ``threshold`` is
    subtracted at the steps where the value reached it.
    """
    batch_size, len_time = alphas.size()
    # loop vars
    integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=alphas.device)
    # intermediate vars along time
    list_fires = []
    for t in range(len_time):
        alpha = alphas[:, t]
        integrate += alpha
        # The appended tensor is not modified later: `integrate` is rebound to a
        # new tensor by torch.where below before the next in-place add.
        list_fires.append(integrate)
        fire_place = integrate >= threshold
        # NOTE(review): torch.ones here uses the default dtype rather than
        # alphas.dtype -- confirm this is intended for non-float32 inputs.
        integrate = torch.where(fire_place,
                                integrate - torch.ones([batch_size], device=alphas.device) * threshold,
                                integrate)
    fires = torch.stack(list_fires, 1)
    return fires
funasr/models/bicif_paraformer/model.py
@@ -342,3 +342,89 @@
                results.append(result_i)
        
        return results, meta_data
    def export(
        self,
        max_seq_len=512,
        **kwargs,
    ):
        is_onnx = kwargs.get("type", "onnx") == "onnx"
        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
        predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
        decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask
        if is_onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        self.forward = self._export_forward
        return self
    def _export_forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        # batch = to_device(batch, device=self.device)
        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.round().type(torch.int32)
        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # get predicted timestamps
        us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
        return decoder_out, pre_token_length, us_alphas, us_cif_peak
    def export_dummy_inputs(self):
        speech = torch.randn(2, 30, 560)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)
    def export_input_names(self):
        return ['speech', 'speech_lengths']
    def export_output_names(self):
        return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
    def export_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
            'us_alphas': {
                0: 'batch_size',
                1: 'alphas_length'
            },
            'us_cif_peak': {
                0: 'batch_size',
                1: 'alphas_length'
            },
        }
    def export_name(self, ):
        return "model.onnx"
funasr/models/ct_transformer/model.py
@@ -365,3 +365,57 @@
        results.append(result_i)
        return results, meta_data
    def export(
        self,
        **kwargs,
    ):
        is_onnx = kwargs.get("type", "onnx") == "onnx"
        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
        self.forward = self._export_forward
        return self
    def _export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
        """Compute loss value from buffer sequences.
        Args:
            input (torch.Tensor): Input ids. (batch, len)
            hidden (torch.Tensor): Target ids. (batch, len)
        """
        x = self.embed(inputs)
        h, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y
    def export_dummy_inputs(self):
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
        text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
        return (text_indexes, text_lengths)
    def export_input_names(self):
        return ['inputs', 'text_lengths']
    def export_output_names(self):
        return ['logits']
    def export_dynamic_axes(self):
        return {
            'inputs': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'text_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
    def export_name(self):
        return "model.onnx"
funasr/models/ct_transformer_streaming/encoder.py
@@ -371,3 +371,108 @@
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
class EncoderLayerSANMExport(torch.nn.Module):
    """Encoder layer wrapper that re-uses the sub-modules of an existing layer."""

    def __init__(
        self,
        model,
    ):
        """Copy attention, feed-forward, norms and size fields from ``model``."""
        super().__init__()
        # Same attribute order as the wrapped layer is read: modules first, sizes last.
        for attr in ("self_attn", "feed_forward", "norm1", "norm2", "in_size", "size"):
            setattr(self, attr, getattr(model, attr))

    def forward(self, x, mask):
        """Apply normed self-attention and feed-forward, each followed by a residual add.

        The attention residual is only added when ``in_size`` equals ``size``.

        Returns:
            Tuple ``(output, mask)`` with ``mask`` passed through unchanged.
        """
        attn_out = self.self_attn(self.norm1(x), mask)
        if self.in_size == self.size:
            attn_out = attn_out + x

        ff_out = self.feed_forward(self.norm2(attn_out))
        return ff_out + attn_out, mask
@tables.register("encoder_classes", "SANMVadEncoderExport")
class SANMVadEncoderExport(torch.nn.Module):
    """Export-oriented wrapper around an existing SANM VAD encoder.

    On construction the layers of the wrapped encoder are replaced in place by
    ``EncoderLayerSANMExport`` instances.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='encoder',
        onnx: bool = True,
    ):
        """Wrap ``model`` and convert its encoder layers.

        Args:
            model: encoder to wrap; its layer containers are modified in place.
            max_seq_len: forwarded to the padding-mask builder.
            feats_dim: not read here.
            model_name: not read here.
            onnx: selects ``MakePadMask`` when truthy, ``sequence_mask`` otherwise.
        """
        super().__init__()
        self.embed = model.embed
        self.model = model
        self._output_size = model._output_size

        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask

        mask_builder = MakePadMask if onnx else sequence_mask
        self.make_pad_mask = mask_builder(max_seq_len, flip=False)

        from funasr.models.sanm.attention import MultiHeadedAttentionSANMExport

        def _convert_layers(layers):
            # Swap masked SANM attention for its export form, then wrap the layer.
            for idx, layer in enumerate(layers):
                if isinstance(layer.self_attn, MultiHeadedAttentionSANMwithMask):
                    layer.self_attn = MultiHeadedAttentionSANMExport(layer.self_attn)
                layers[idx] = EncoderLayerSANMExport(layer)

        if hasattr(model, 'encoders0'):
            _convert_layers(self.model.encoders0)
        _convert_layers(self.model.encoders)

    def prepare_mask(self, mask, sub_masks):
        """Return ``(mask[:, :, None], (1 - sub_masks) * -10000.0)``."""
        return mask[:, :, None], (1 - sub_masks) * -10000.0

    def forward(self,
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                vad_masks: torch.Tensor,
                sub_masks: torch.Tensor,
                ):
        """Run the wrapped encoder; the last layer uses the VAD mask pair.

        Returns:
            Tuple ``(xs_pad, speech_lengths)`` with ``xs_pad`` passed through
            the wrapped encoder's ``after_norm``.
        """
        speech = speech * self._output_size ** 0.5
        pad_mask = self.make_pad_mask(speech_lengths)
        vad_masks = self.prepare_mask(pad_mask, vad_masks)
        mask = self.prepare_mask(pad_mask, sub_masks)

        xs_pad = speech if self.embed is None else self.embed(speech)

        first_outs = self.model.encoders0(xs_pad, mask)
        xs_pad, _ = first_outs[0], first_outs[1]

        layers = self.model.encoders
        for layer_idx, encoder_layer in enumerate(layers):
            # Only the final layer sees the VAD mask pair.
            if layer_idx == len(layers) - 1:
                mask = vad_masks
            layer_outs = encoder_layer(xs_pad, mask)
            xs_pad, _ = layer_outs[0], layer_outs[1]

        return self.model.after_norm(xs_pad), speech_lengths

    def get_output_size(self):
        """Return the ``size`` attribute of the first encoder layer."""
        first_layer = self.model.encoders[0]
        return first_layer.size
funasr/models/ct_transformer_streaming/model.py
@@ -173,3 +173,68 @@
    
        return results, meta_data
    def export(
        self,
        **kwargs,
    ):
        is_onnx = kwargs.get("type", "onnx") == "onnx"
        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
        self.forward = self._export_forward
        return self
    def _export_forward(self, inputs: torch.Tensor,
                text_lengths: torch.Tensor,
                vad_indexes: torch.Tensor,
                sub_masks: torch.Tensor,
                ):
        """Compute loss value from buffer sequences.
        Args:
            input (torch.Tensor): Input ids. (batch, len)
            hidden (torch.Tensor): Target ids. (batch, len)
        """
        x = self.embed(inputs)
        # mask = self._target_mask(input)
        h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
        y = self.decoder(h)
        return y
    def export_dummy_inputs(self):
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
        text_lengths = torch.tensor([length], dtype=torch.int32)
        vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
        sub_masks = torch.ones(length, length, dtype=torch.float32)
        sub_masks = torch.tril(sub_masks).type(torch.float32)
        return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
    def export_input_names(self):
        return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
    def export_output_names(self):
        return ['logits']
    def export_dynamic_axes(self):
        return {
            'inputs': {
                1: 'feats_length'
            },
            'vad_masks': {
                2: 'feats_length1',
                3: 'feats_length2'
            },
            'sub_masks': {
                2: 'feats_length1',
                3: 'feats_length2'
            },
            'logits': {
                1: 'logits_length'
            },
        }
    def export_name(self):
        return "model.onnx"
funasr/models/fsmn_vad_streaming/encoder.py
@@ -194,7 +194,7 @@
            output_affine_dim: int,
            output_dim: int
    ):
        super(FSMN, self).__init__()
        super().__init__()
        self.input_dim = input_dim
        self.input_affine_dim = input_affine_dim
@@ -212,12 +212,6 @@
        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
        self.softmax = nn.Softmax(dim=-1)
        # export onnx or torchscripts
        if "EXPORTING_MODEL" in os.environ and os.environ['EXPORTING_MODEL'] == 'TRUE':
            for i, d in enumerate(self.fsmn):
                if isinstance(d, BasicBlock):
                    self.fsmn[i] = BasicBlock_export(d)
    def fuse_modules(self):
        pass
@@ -244,7 +238,46 @@
        return x7
    def export_forward(
@tables.register("encoder_classes", "FSMNExport")
class FSMNExport(nn.Module):
    def __init__(
        self, model, **kwargs,
    ):
        super().__init__()
        # self.input_dim = input_dim
        # self.input_affine_dim = input_affine_dim
        # self.fsmn_layers = fsmn_layers
        # self.linear_dim = linear_dim
        # self.proj_dim = proj_dim
        # self.output_affine_dim = output_affine_dim
        # self.output_dim = output_dim
        #
        # self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
        # self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
        # self.relu = RectifiedLinear(linear_dim, linear_dim)
        # self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
        #                         range(fsmn_layers)])
        # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        # self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
        # self.softmax = nn.Softmax(dim=-1)
        self.in_linear1 = model.in_linear1
        self.in_linear2 = model.in_linear2
        self.relu = model.relu
        # self.fsmn = model.fsmn
        self.out_linear1 = model.out_linear1
        self.out_linear2 = model.out_linear2
        self.softmax = model.softmax
        self.fsmn = model.fsmn
        for i, d in enumerate(model.fsmn):
            if isinstance(d, BasicBlock):
                self.fsmn[i] = BasicBlock_export(d)
    def fuse_modules(self):
        pass
    def forward(
            self,
            input: torch.Tensor,
            *args,
@@ -271,6 +304,7 @@
        return x, out_caches
'''
one deep fsmn layer
dimproj:                projection dimension, input and output dimension of memory blocks
funasr/models/fsmn_vad_streaming/model.py
@@ -644,12 +644,17 @@
        return results, meta_data
    
    def export(self, **kwargs):
        is_onnx = kwargs.get("type", "onnx") == "onnx"
        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
        self.forward = self._export_forward
        
        return self
        
    def _export_forward(self, feats: torch.Tensor, *args, **kwargs):
        scores, out_caches = self.encoder.export_forward(feats, *args)
        scores, out_caches = self.encoder(feats, *args)
        return scores, out_caches
    
    def export_dummy_inputs(self, data_in=None, frame=30):
funasr/models/paraformer/cif_predictor.py
@@ -376,7 +376,7 @@
        return predictor_alignments.detach(), predictor_alignments_length.detach()
@tables.register("predictor_classes", "CifPredictorV2Export")
class CifPredictorV2(torch.nn.Module):
class CifPredictorV2Export(torch.nn.Module):
    def __init__(self, model, **kwargs):
        super().__init__()
        
funasr/models/paraformer/decoder.py
@@ -635,8 +635,9 @@
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        
        from funasr.models.sanm.multihead_att import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.multihead_att import MultiHeadedAttentionCrossAttExport
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
        
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
funasr/models/paraformer/model.py
@@ -554,21 +554,23 @@
        max_seq_len=512,
        **kwargs,
    ):
        onnx = kwargs.get("onnx", True)
        is_onnx = kwargs.get("type", "onnx") == "onnx"
        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
        self.encoder = encoder_class(self.encoder, onnx=onnx)
        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
        
        predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
        self.predictor = predictor_class(self.predictor, onnx=onnx)
        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
        decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
        self.decoder = decoder_class(self.decoder, onnx=onnx)
        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
        
        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask
        
        if onnx:
        if is_onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
funasr/models/sanm/attention.py
@@ -17,6 +17,24 @@
from funasr.models.transformer.utils.nets_utils import make_pad_mask
import funasr.models.lora.layers as lora
def preprocess_for_attn(x, mask, cache, pad_fn, kernel_size):
    """Mask and transpose ``x``, then either pad it or prepend the cache.

    Args:
        x: input tensor; it is multiplied by ``mask`` and has axes 1 and 2
            swapped.
        mask: tensor multiplied element-wise with ``x`` (must broadcast).
        cache: ``None`` or a tensor concatenated in front of the transposed
            input along axis 2.
        pad_fn: callable applied to the transposed input when ``cache`` is
            ``None``.
        kernel_size: the new cache is sliced as ``[:, :, -(kernel_size - 1):]``
            from the concatenated tensor.

    Returns:
        ``(x, cache)``: the prepared tensor, and either the unchanged ``None``
        cache or the refreshed cache slice.
    """
    masked = (x * mask).transpose(1, 2)
    if cache is not None:
        joined = torch.cat((cache, masked), dim=2)
        keep = kernel_size - 1
        return joined, joined[:, :, -keep:]
    return pad_fn(masked), cache
# (major, minor) of the installed torch, parsed from its version string.
torch_version = tuple(int(part) for part in torch.__version__.split(".")[:2])
if torch_version >= (1, 8):
    # Register the helper by name with torch.fx so symbolic tracing treats it
    # as a single call instead of tracing into its body.
    import torch.fx
    torch.fx.wrap('preprocess_for_attn')
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.
@@ -362,6 +380,65 @@
        return self.linear_out(context_layer)  # (batch, time1, d_model)
class MultiHeadedAttentionSANMExport(nn.Module):
    """Export wrapper that reuses the layers of an existing SANM attention module.

    The wrapped ``model`` must expose ``d_k``, ``h``, ``linear_out``,
    ``linear_q_k_v``, ``fsmn_block`` and ``pad_fn``; they are shared, not copied.
    """

    def __init__(self, model):
        super().__init__()
        self.d_k, self.h = model.d_k, model.h
        self.linear_out = model.linear_out
        self.linear_q_k_v = model.linear_q_k_v
        self.fsmn_block = model.fsmn_block
        self.pad_fn = model.pad_fn
        # Holds the most recent attention weights once forward_attention ran.
        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, x, mask):
        """Return attention output plus the FSMN memory term.

        Args:
            x: input passed to the fused q/k/v projection.
            mask: pair ``(fsmn_mask, attn_mask)``; the first is multiplied into
                the FSMN branch, the second is added to the attention scores.
        """
        fsmn_mask, attn_mask = mask
        queries, keys, values, flat_values = self.forward_qkv(x)
        memory = self.forward_fsmn(flat_values, fsmn_mask)
        scaled_queries = queries * self.d_k**(-0.5)
        scores = torch.matmul(scaled_queries, keys.transpose(-2, -1))
        return self.forward_attention(values, scores, attn_mask) + memory

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into ``(h, d_k)`` and permute axes to (0, 2, 1, 3)."""
        head_shape = x.size()[:-1] + (self.h, self.d_k)
        return x.view(head_shape).permute(0, 2, 1, 3)

    def forward_qkv(self, x):
        """Project ``x`` once and return per-head q, k, v plus the un-split v."""
        fused = self.linear_q_k_v(x)
        q, k, v = torch.split(fused, int(self.h * self.d_k), dim=-1)
        return (
            self.transpose_for_scores(q),
            self.transpose_for_scores(k),
            self.transpose_for_scores(v),
            v,
        )

    def forward_fsmn(self, inputs, mask):
        """Run the masked input through pad + FSMN block with a residual add."""
        masked = inputs * mask
        memory = self.fsmn_block(self.pad_fn(masked.transpose(1, 2))).transpose(1, 2)
        return (memory + masked) * mask

    def forward_attention(self, value, scores, mask):
        """Softmax over ``scores + mask``, apply to ``value`` and project out."""
        self.attn = torch.softmax(scores + mask, dim=-1)
        # (batch, head, time1, d_k) -> (batch, time1, head, d_k)
        context = torch.matmul(self.attn, value).permute(0, 2, 1, 3).contiguous()
        merged = context.view(context.size()[:-2] + (self.all_head_size,))
        return self.linear_out(merged)  # (batch, time1, d_model)
class MultiHeadedAttentionSANMDecoder(nn.Module):
    """Multi-Head Attention layer.
@@ -375,7 +452,7 @@
    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct an MultiHeadedAttention object."""
        super(MultiHeadedAttentionSANMDecoder, self).__init__()
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
@@ -440,6 +517,24 @@
            x = x * mask
        return x, cache
class MultiHeadedAttentionSANMDecoderExport(nn.Module):
    """Export wrapper reusing the FSMN block of a SANM decoder attention module.

    The wrapped ``model`` must expose ``fsmn_block``, ``pad_fn`` and
    ``kernel_size``.
    """

    def __init__(self, model):
        super().__init__()
        self.fsmn_block = model.fsmn_block
        self.pad_fn = model.pad_fn
        self.kernel_size = model.kernel_size
        # Never assigned anywhere else in this class; stays None.
        self.attn = None

    def forward(self, inputs, mask, cache=None):
        """Apply the FSMN memory block with a residual connection.

        Args:
            inputs: tensor fed to ``preprocess_for_attn`` and re-added as the
                residual.
            mask: multiplied into the input (inside the helper) and the output.
            cache: optional cache forwarded to ``preprocess_for_attn``.

        Returns:
            ``(output, cache)`` where ``cache`` is whatever the helper returned.
        """
        prepared, cache = preprocess_for_attn(
            inputs, mask, cache, self.pad_fn, self.kernel_size
        )
        memory = self.fsmn_block(prepared).transpose(1, 2)
        return (memory + inputs) * mask, cache
class MultiHeadedAttentionCrossAtt(nn.Module):
    """Multi-Head Attention layer.
@@ -452,7 +547,7 @@
    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
        """Construct an MultiHeadedAttention object."""
        super(MultiHeadedAttentionCrossAtt, self).__init__()
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
@@ -591,6 +686,48 @@
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, None), cache
class MultiHeadedAttentionCrossAttExport(nn.Module):
    """Export wrapper that reuses the layers of an existing cross-attention module.

    The wrapped ``model`` must expose ``d_k``, ``h``, ``linear_q``,
    ``linear_k_v`` and ``linear_out``; they are shared, not copied.
    """

    def __init__(self, model):
        super().__init__()
        self.d_k, self.h = model.d_k, model.h
        self.linear_q = model.linear_q
        self.linear_k_v = model.linear_k_v
        self.linear_out = model.linear_out
        # Holds the most recent attention weights once forward_attention ran.
        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, x, memory, memory_mask):
        """Attend from ``x`` (queries) over ``memory`` (keys and values).

        ``memory_mask`` is added to the scaled scores before the softmax.
        """
        queries, keys, values = self.forward_qkv(x, memory)
        raw_scores = torch.matmul(queries, keys.transpose(-2, -1))
        return self.forward_attention(values, raw_scores / math.sqrt(self.d_k), memory_mask)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into ``(h, d_k)`` and permute axes to (0, 2, 1, 3)."""
        head_shape = x.size()[:-1] + (self.h, self.d_k)
        return x.view(head_shape).permute(0, 2, 1, 3)

    def forward_qkv(self, x, memory):
        """Project ``x`` to q and ``memory`` to fused k/v, returning per-head tensors."""
        q = self.linear_q(x)
        fused_kv = self.linear_k_v(memory)
        k, v = torch.split(fused_kv, int(self.h * self.d_k), dim=-1)
        return (
            self.transpose_for_scores(q),
            self.transpose_for_scores(k),
            self.transpose_for_scores(v),
        )

    def forward_attention(self, value, scores, mask):
        """Softmax over ``scores + mask``, apply to ``value`` and project out."""
        self.attn = torch.softmax(scores + mask, dim=-1)
        # (batch, head, time1, d_k) -> (batch, time1, head, d_k)
        context = torch.matmul(self.attn, value).permute(0, 2, 1, 3).contiguous()
        merged = context.view(context.size()[:-2] + (self.all_head_size,))
        return self.linear_out(merged)  # (batch, time1, d_model)
class MultiHeadSelfAttention(nn.Module):
    """Multi-Head Attention layer.
funasr/models/sanm/attention_export.py
New file
@@ -0,0 +1,114 @@
import os
import math
import torch
import torch.nn as nn
class OnnxMultiHeadedAttention(nn.Module):
    """Export wrapper that reuses the layers of an existing multi-head attention.

    The wrapped ``model`` must expose ``d_k``, ``h``, ``linear_q``,
    ``linear_k``, ``linear_v`` and ``linear_out``; they are shared, not copied.
    """

    def __init__(self, model):
        super().__init__()
        self.d_k, self.h = model.d_k, model.h
        self.linear_q = model.linear_q
        self.linear_k = model.linear_k
        self.linear_v = model.linear_v
        self.linear_out = model.linear_out
        # Holds the most recent attention weights once forward_attention ran.
        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, query, key, value, mask):
        """Scaled dot-product attention; ``mask`` is added to the scores."""
        q, k, v = self.forward_qkv(query, key, value)
        raw_scores = torch.matmul(q, k.transpose(-2, -1))
        return self.forward_attention(v, raw_scores / math.sqrt(self.d_k), mask)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into ``(h, d_k)`` and permute axes to (0, 2, 1, 3)."""
        head_shape = x.size()[:-1] + (self.h, self.d_k)
        return x.view(head_shape).permute(0, 2, 1, 3)

    def forward_qkv(self, query, key, value):
        """Apply the three input projections and reshape each result per head."""
        projected_q = self.linear_q(query)
        projected_k = self.linear_k(key)
        projected_v = self.linear_v(value)
        return (
            self.transpose_for_scores(projected_q),
            self.transpose_for_scores(projected_k),
            self.transpose_for_scores(projected_v),
        )

    def forward_attention(self, value, scores, mask):
        """Softmax over ``scores + mask``, apply to ``value`` and project out."""
        self.attn = torch.softmax(scores + mask, dim=-1)
        # (batch, head, time1, d_k) -> (batch, time1, head, d_k)
        context = torch.matmul(self.attn, value).permute(0, 2, 1, 3).contiguous()
        merged = context.view(context.size()[:-2] + (self.all_head_size,))
        return self.linear_out(merged)  # (batch, time1, d_model)
class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention):
    """Export wrapper for multi-head attention with relative positional terms.

    In addition to what the parent class reuses, the wrapped ``model`` must
    expose ``linear_pos``, ``pos_bias_u`` and ``pos_bias_v``.
    """
    def __init__(self, model):
        super().__init__(model)
        self.linear_pos = model.linear_pos
        self.pos_bias_u = model.pos_bias_u
        self.pos_bias_v = model.pos_bias_v
    def forward(self, query, key, value, pos_emb, mask):
        """Compute attention whose scores combine content and positional terms.

        Args:
            query, key, value: inputs to the q/k/v projections of the parent.
            pos_emb: positional embedding; projected by ``linear_pos`` and split
                into heads (exact expected shape is not visible here -- confirm
                against the caller).
            mask: added to the scores before the softmax.

        Returns:
            The output of ``forward_attention`` for the combined scores.
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
        p = self.transpose_for_scores(self.linear_pos(pos_emb)) # (batch, head, time1, d_k)
        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)
        # Both terms are summed first and scaled by 1/sqrt(d_k) together.
        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)
        return self.forward_attention(v, scores, mask)
    def rel_shift(self, x):
        """Shift the positional score matrix and truncate its last axis.

        Steps, as written below: prepend one zero column on the last axis,
        reinterpret the padded tensor with its last two axis sizes swapped,
        drop the first row, view the remainder back to ``x``'s shape, then keep
        only the first ``x.size(-1) // 2 + 1`` entries of the last axis.
        ``x`` is indexed as a 4-D tensor (``x.size(3)`` is read).
        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)
        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2
        return x
    # NOTE(review): this override is line-for-line the same as
    # OnnxMultiHeadedAttention.forward_attention; it looks redundant.
    def forward_attention(self, value, scores, mask):
        """Softmax over ``scores + mask``, apply to ``value`` and project out."""
        scores = scores + mask
        # Kept on the instance so the latest attention weights can be inspected.
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)
funasr/utils/export_utils.py
@@ -3,7 +3,6 @@
def export_onnx(model,
                data_in=None,
                type: str = "onnx",
                quantize: bool = False,
                fallback_num: int = 5,
                calib_num: int = 100,
@@ -19,7 +18,6 @@
        m.eval()
        _onnx(m,
              data_in=data_in,
              type=type,
              quantize=quantize,
              fallback_num=fallback_num,
              calib_num=calib_num,
setup.py
@@ -140,5 +140,6 @@
    ],
    entry_points={"console_scripts": [
        "funasr = funasr.bin.inference:main_hydra",
        "funasr-export = funasr.bin.export:main_hydra",
    ]},
)