From 9d48230c4f8f25bf88c5d6105f97370a36c9cf43 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 11 三月 2024 10:48:50 +0800
Subject: [PATCH] export onnx (#1457)
---
funasr/models/paraformer/model.py | 12
examples/industrial_data_pretraining/bicif_paraformer/export.sh | 28 +
funasr/bin/export.py | 3
funasr/models/ct_transformer/model.py | 54 +++
examples/industrial_data_pretraining/paraformer/export.py | 17
setup.py | 1
examples/industrial_data_pretraining/ct_transformer/export.py | 26 +
examples/industrial_data_pretraining/bicif_paraformer/export.py | 26 +
funasr/utils/export_utils.py | 2
examples/industrial_data_pretraining/paraformer/README_zh.md | 14
funasr/models/bicif_paraformer/model.py | 88 ++++
funasr/models/sanm/attention_export.py | 114 ++++++
examples/industrial_data_pretraining/paraformer/export.sh | 14
funasr/models/ct_transformer_streaming/encoder.py | 105 +++++
examples/industrial_data_pretraining/ct_transformer/export.sh | 28 +
funasr/models/sanm/attention.py | 141 +++++++
examples/industrial_data_pretraining/ct_transformer_streaming/export.py | 26 +
funasr/models/fsmn_vad_streaming/model.py | 7
funasr/auto/auto_model.py | 4
funasr/models/paraformer/cif_predictor.py | 2
funasr/models/bicif_paraformer/cif_predictor.py | 163 +++++++++
funasr/models/fsmn_vad_streaming/encoder.py | 60 ++
examples/industrial_data_pretraining/fsmn_vad_streaming/export.py | 16
funasr/models/paraformer/decoder.py | 5
funasr/models/ct_transformer_streaming/model.py | 65 +++
examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh | 18
examples/industrial_data_pretraining/ct_transformer_streaming/export.sh | 28 +
27 files changed, 1,029 insertions(+), 38 deletions(-)
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/export.py b/examples/industrial_data_pretraining/bicif_paraformer/export.py
new file mode 100644
index 0000000..78e7295
--- /dev/null
+++ b/examples/industrial_data_pretraining/bicif_paraformer/export.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+from funasr import AutoModel
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+ model_revision="v2.0.4")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
+
+
+# # method2, inference from local path
+# from funasr import AutoModel
+#
+# wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+#
+# model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+#
+# res = model.export(input=wav_file, type="onnx", quantize=False)
+# print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/export.sh b/examples/industrial_data_pretraining/bicif_paraformer/export.sh
new file mode 100644
index 0000000..bc20a90
--- /dev/null
+++ b/examples/industrial_data_pretraining/bicif_paraformer/export.sh
@@ -0,0 +1,28 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+export HYDRA_FULL_ERROR=1
+
+
+model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model_revision="v2.0.4"
+
+python -m funasr.bin.export \
+++model=${model} \
+++model_revision=${model_revision} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
+
+
+# method2, inference from local path
+model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+
+python -m funasr.bin.export \
+++model=${model} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/export.py b/examples/industrial_data_pretraining/ct_transformer/export.py
new file mode 100644
index 0000000..3321525
--- /dev/null
+++ b/examples/industrial_data_pretraining/ct_transformer/export.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+from funasr import AutoModel
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
+
+model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+ model_revision="v2.0.4")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
+
+
+# method2, inference from local path
+from funasr import AutoModel
+
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/export.sh b/examples/industrial_data_pretraining/ct_transformer/export.sh
new file mode 100644
index 0000000..f7849a1
--- /dev/null
+++ b/examples/industrial_data_pretraining/ct_transformer/export.sh
@@ -0,0 +1,28 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+export HYDRA_FULL_ERROR=1
+
+
+model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+model_revision="v2.0.4"
+
+python -m funasr.bin.export \
+++model=${model} \
+++model_revision=${model_revision} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
+
+
+# method2, inference from local path
+model="/Users/zhifu/.cache/modelscope/hub/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+
+python -m funasr.bin.export \
+++model=${model} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer_streaming/export.py b/examples/industrial_data_pretraining/ct_transformer_streaming/export.py
new file mode 100644
index 0000000..4e50501
--- /dev/null
+++ b/examples/industrial_data_pretraining/ct_transformer_streaming/export.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+from funasr import AutoModel
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt"
+
+model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
+ model_revision="v2.0.4")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
+
+
+# method2, inference from local path
+from funasr import AutoModel
+
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer_streaming/export.sh b/examples/industrial_data_pretraining/ct_transformer_streaming/export.sh
new file mode 100644
index 0000000..118afbb
--- /dev/null
+++ b/examples/industrial_data_pretraining/ct_transformer_streaming/export.sh
@@ -0,0 +1,28 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+export HYDRA_FULL_ERROR=1
+
+
+model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
+model_revision="v2.0.4"
+
+python -m funasr.bin.export \
+++model=${model} \
+++model_revision=${model_revision} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
+
+
+# method2, inference from local path
+model="/Users/zhifu/.cache/modelscope/hub/iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
+
+python -m funasr.bin.export \
+++model=${model} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
index d259104..2e09523 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
@@ -3,10 +3,24 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
from funasr import AutoModel
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
-model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4", export_model=True)
+model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
+
+# method2, inference from local path
+
+from funasr import AutoModel
+
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch")
res = model.export(input=wav_file, type="onnx", quantize=False)
print(res)
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
index 0bb4617..911a1a1 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
@@ -6,14 +6,24 @@
export HYDRA_FULL_ERROR=1
-model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
model_revision="v2.0.4"
-python funasr/bin/export.py \
+python -m funasr.bin.export \
++model=${model} \
++model_revision=${model_revision} \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++type="onnx" \
++quantize=false \
-++device="cpu" \
-++debug=false
+++device="cpu"
+
+
+# method2, inference from local path
+model="/Users/zhifu/.cache/modelscope/hub/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+
+python -m funasr.bin.export \
+++model=${model} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu"
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/README_zh.md b/examples/industrial_data_pretraining/paraformer/README_zh.md
index f9ab616..38a4455 100644
--- a/examples/industrial_data_pretraining/paraformer/README_zh.md
+++ b/examples/industrial_data_pretraining/paraformer/README_zh.md
@@ -78,4 +78,18 @@
```bash
tensorboard --logdir /xxxx/FunASR/examples/industrial_data_pretraining/paraformer/outputs/log/tensorboard
+```
+
+
+## 瀵煎嚭onnx
+
+```python
+from funasr import AutoModel
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+ model_revision="v2.0.4")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
```
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/export.py b/examples/industrial_data_pretraining/paraformer/export.py
index 613c3a9..43b3c18 100644
--- a/examples/industrial_data_pretraining/paraformer/export.py
+++ b/examples/industrial_data_pretraining/paraformer/export.py
@@ -3,11 +3,26 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+
from funasr import AutoModel
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
model = AutoModel(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
- model_revision="v2.0.4", export_model=True)
+ model_revision="v2.0.4")
res = model.export(input=wav_file, type="onnx", quantize=False)
print(res)
+
+
+# method2, inference from local path
+from funasr import AutoModel
+
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/export.sh b/examples/industrial_data_pretraining/paraformer/export.sh
index 9f45a5a..c67dca8 100644
--- a/examples/industrial_data_pretraining/paraformer/export.sh
+++ b/examples/industrial_data_pretraining/paraformer/export.sh
@@ -8,11 +8,23 @@
model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model_revision="v2.0.4"
-python funasr/bin/export.py \
+
+python -m funasr.bin.export \
++model=${model} \
++model_revision=${model_revision} \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++type="onnx" \
++quantize=false \
+++device="cpu"
+
+
+# method2, inference from local path
+model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+
+python -m funasr.bin.export \
+++model=${model} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
++device="cpu" \
++debug=false
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index d7b6cb9..c4bab03 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -100,9 +100,7 @@
def __init__(self, **kwargs):
if not kwargs.get("disable_log", True):
tables.print()
- if kwargs.get("export_model", False):
- os.environ['EXPORTING_MODEL'] = 'TRUE'
-
+
model, kwargs = self.build_model(**kwargs)
# if vad_model is not None, build vad model else None
diff --git a/funasr/bin/export.py b/funasr/bin/export.py
index 68acc17..7d47664 100644
--- a/funasr/bin/export.py
+++ b/funasr/bin/export.py
@@ -24,8 +24,9 @@
if kwargs.get("debug", False):
import pdb; pdb.set_trace()
+
+ model = AutoModel(**kwargs)
- model = AutoModel(export_model=True, **kwargs)
res = model.export(input=kwargs.get("input", None),
type=kwargs.get("type", "onnx"),
quantize=kwargs.get("quantize", False),
diff --git a/funasr/models/bicif_paraformer/cif_predictor.py b/funasr/models/bicif_paraformer/cif_predictor.py
index e7b3ba9..2cdbc16 100644
--- a/funasr/models/bicif_paraformer/cif_predictor.py
+++ b/funasr/models/bicif_paraformer/cif_predictor.py
@@ -336,3 +336,166 @@
predictor_alignments = index_div_bool_zeros_count_tile_out
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
return predictor_alignments.detach(), predictor_alignments_length.detach()
+
+@tables.register("predictor_classes", "CifPredictorV3Export")
+class CifPredictorV3Export(torch.nn.Module):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+
+ self.pad = model.pad
+ self.cif_conv1d = model.cif_conv1d
+ self.cif_output = model.cif_output
+ self.threshold = model.threshold
+ self.smooth_factor = model.smooth_factor
+ self.noise_threshold = model.noise_threshold
+ self.tail_threshold = model.tail_threshold
+
+ self.upsample_times = model.upsample_times
+ self.upsample_cnn = model.upsample_cnn
+ self.blstm = model.blstm
+ self.cif_output2 = model.cif_output2
+ self.smooth_factor2 = model.smooth_factor2
+ self.noise_threshold2 = model.noise_threshold2
+
+ def forward(self, hidden: torch.Tensor,
+ mask: torch.Tensor,
+ ):
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ output = torch.relu(self.cif_conv1d(queries))
+ output = output.transpose(1, 2)
+
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+ mask = mask.transpose(-1, -2).float()
+ alphas = alphas * mask
+ alphas = alphas.squeeze(-1)
+ token_num = alphas.sum(-1)
+
+ mask = mask.squeeze(-1)
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
+ acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
+
+ return acoustic_embeds, token_num, alphas, cif_peak
+
+ def get_upsample_timestmap(self, hidden, mask=None, token_num=None):
+ h = hidden
+ b = hidden.shape[0]
+ context = h.transpose(1, 2)
+
+ # generate alphas2
+ _output = context
+ output2 = self.upsample_cnn(_output)
+ output2 = output2.transpose(1, 2)
+ output2, (_, _) = self.blstm(output2)
+ alphas2 = torch.sigmoid(self.cif_output2(output2))
+ alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
+
+ mask = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
+ mask = mask.unsqueeze(-1)
+ alphas2 = alphas2 * mask
+ alphas2 = alphas2.squeeze(-1)
+ _token_num = alphas2.sum(-1)
+ alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
+ # upsampled alphas and cif_peak
+ us_alphas = alphas2
+ us_cif_peak = cif_wo_hidden_export(us_alphas, self.threshold - 1e-4)
+ return us_alphas, us_cif_peak
+
+ def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
+ b, t, d = hidden.size()
+ tail_threshold = self.tail_threshold
+
+ zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
+ ones_t = torch.ones_like(zeros_t)
+
+ mask_1 = torch.cat([mask, zeros_t], dim=1)
+ mask_2 = torch.cat([ones_t, mask], dim=1)
+ mask = mask_2 - mask_1
+ tail_threshold = mask * tail_threshold
+ alphas = torch.cat([alphas, zeros_t], dim=1)
+ alphas = torch.add(alphas, tail_threshold)
+
+ zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
+ hidden = torch.cat([hidden, zeros], dim=1)
+ token_num = alphas.sum(dim=-1)
+ token_num_floor = torch.floor(token_num)
+
+ return hidden, alphas, token_num_floor
+
+
+@torch.jit.script
+def cif_export(hidden, alphas, threshold: float):
+ batch_size, len_time, hidden_size = hidden.size()
+ threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
+
+ # loop varss
+ integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
+ frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
+ # intermediate vars along time
+ list_fires = []
+ list_frames = []
+
+ for t in range(len_time):
+ alpha = alphas[:, t]
+ distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
+
+ integrate += alpha
+ list_fires.append(integrate)
+
+ fire_place = integrate >= threshold
+ integrate = torch.where(fire_place,
+ integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
+ integrate)
+ cur = torch.where(fire_place,
+ distribution_completion,
+ alpha)
+ remainds = alpha - cur
+
+ frame += cur[:, None] * hidden[:, t, :]
+ list_frames.append(frame)
+ frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
+ remainds[:, None] * hidden[:, t, :],
+ frame)
+
+ fires = torch.stack(list_fires, 1)
+ frames = torch.stack(list_frames, 1)
+
+ fire_idxs = fires >= threshold
+ frame_fires = torch.zeros_like(hidden)
+ max_label_len = frames[0, fire_idxs[0]].size(0)
+ for b in range(batch_size):
+ frame_fire = frames[b, fire_idxs[b]]
+ frame_len = frame_fire.size(0)
+ frame_fires[b, :frame_len, :] = frame_fire
+
+ if frame_len >= max_label_len:
+ max_label_len = frame_len
+ frame_fires = frame_fires[:, :max_label_len, :]
+ return frame_fires, fires
+
+
+@torch.jit.script
+def cif_wo_hidden_export(alphas, threshold: float):
+ batch_size, len_time = alphas.size()
+
+ # loop varss
+ integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=alphas.device)
+ # intermediate vars along time
+ list_fires = []
+
+ for t in range(len_time):
+ alpha = alphas[:, t]
+
+ integrate += alpha
+ list_fires.append(integrate)
+
+ fire_place = integrate >= threshold
+ integrate = torch.where(fire_place,
+ integrate - torch.ones([batch_size], device=alphas.device) * threshold,
+ integrate)
+
+ fires = torch.stack(list_fires, 1)
+ return fires
\ No newline at end of file
diff --git a/funasr/models/bicif_paraformer/model.py b/funasr/models/bicif_paraformer/model.py
index 696cd56..eb7318b 100644
--- a/funasr/models/bicif_paraformer/model.py
+++ b/funasr/models/bicif_paraformer/model.py
@@ -341,4 +341,90 @@
result_i = {"key": key[i], "token_int": token_int}
results.append(result_i)
- return results, meta_data
\ No newline at end of file
+ return results, meta_data
+
+ def export(
+ self,
+ max_seq_len=512,
+ **kwargs,
+ ):
+ is_onnx = kwargs.get("type", "onnx") == "onnx"
+ encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+ self.encoder = encoder_class(self.encoder, onnx=is_onnx)
+
+ predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
+ self.predictor = predictor_class(self.predictor, onnx=is_onnx)
+
+ decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
+ self.decoder = decoder_class(self.decoder, onnx=is_onnx)
+
+ from funasr.utils.torch_function import MakePadMask
+ from funasr.utils.torch_function import sequence_mask
+
+ if is_onnx:
+ self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+ else:
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ self.forward = self._export_forward
+
+ return self
+
+ def _export_forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ ):
+ # a. To device
+ batch = {"speech": speech, "speech_lengths": speech_lengths}
+ # batch = to_device(batch, device=self.device)
+
+ enc, enc_len = self.encoder(**batch)
+ mask = self.make_pad_mask(enc_len)[:, None, :]
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
+ pre_token_length = pre_token_length.round().type(torch.int32)
+
+ decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
+
+ # get predicted timestamps
+ us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
+
+ return decoder_out, pre_token_length, us_alphas, us_cif_peak
+
+ def export_dummy_inputs(self):
+ speech = torch.randn(2, 30, 560)
+ speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+ return (speech, speech_lengths)
+
+ def export_input_names(self):
+ return ['speech', 'speech_lengths']
+
+ def export_output_names(self):
+ return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
+
+ def export_dynamic_axes(self):
+ return {
+ 'speech': {
+ 0: 'batch_size',
+ 1: 'feats_length'
+ },
+ 'speech_lengths': {
+ 0: 'batch_size',
+ },
+ 'logits': {
+ 0: 'batch_size',
+ 1: 'logits_length'
+ },
+ 'us_alphas': {
+ 0: 'batch_size',
+ 1: 'alphas_length'
+ },
+ 'us_cif_peak': {
+ 0: 'batch_size',
+ 1: 'alphas_length'
+ },
+ }
+
+ def export_name(self, ):
+ return "model.onnx"
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index 45f5746..31b8c27 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -365,3 +365,57 @@
results.append(result_i)
return results, meta_data
+ def export(
+ self,
+ **kwargs,
+ ):
+
+ is_onnx = kwargs.get("type", "onnx") == "onnx"
+ encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
+ self.encoder = encoder_class(self.encoder, onnx=is_onnx)
+
+ self.forward = self._export_forward
+
+ return self
+
+ def _export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
+ """Compute loss value from buffer sequences.
+
+ Args:
+ input (torch.Tensor): Input ids. (batch, len)
+ hidden (torch.Tensor): Target ids. (batch, len)
+
+ """
+ x = self.embed(inputs)
+ h, _ = self.encoder(x, text_lengths)
+ y = self.decoder(h)
+ return y
+
+ def export_dummy_inputs(self):
+ length = 120
+ text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
+ text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
+ return (text_indexes, text_lengths)
+
+ def export_input_names(self):
+ return ['inputs', 'text_lengths']
+
+ def export_output_names(self):
+ return ['logits']
+
+ def export_dynamic_axes(self):
+ return {
+ 'inputs': {
+ 0: 'batch_size',
+ 1: 'feats_length'
+ },
+ 'text_lengths': {
+ 0: 'batch_size',
+ },
+ 'logits': {
+ 0: 'batch_size',
+ 1: 'logits_length'
+ },
+ }
+ def export_name(self):
+ return "model.onnx"
\ No newline at end of file
diff --git a/funasr/models/ct_transformer_streaming/encoder.py b/funasr/models/ct_transformer_streaming/encoder.py
index 95e2a4b..badf5f6 100644
--- a/funasr/models/ct_transformer_streaming/encoder.py
+++ b/funasr/models/ct_transformer_streaming/encoder.py
@@ -371,3 +371,108 @@
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
+
+
+class EncoderLayerSANMExport(torch.nn.Module):
+ def __init__(
+ self,
+ model,
+ ):
+ """Construct an EncoderLayer object."""
+ super().__init__()
+ self.self_attn = model.self_attn
+ self.feed_forward = model.feed_forward
+ self.norm1 = model.norm1
+ self.norm2 = model.norm2
+ self.in_size = model.in_size
+ self.size = model.size
+
+ def forward(self, x, mask):
+
+ residual = x
+ x = self.norm1(x)
+ x = self.self_attn(x, mask)
+ if self.in_size == self.size:
+ x = x + residual
+ residual = x
+ x = self.norm2(x)
+ x = self.feed_forward(x)
+ x = x + residual
+
+ return x, mask
+
+@tables.register("encoder_classes", "SANMVadEncoderExport")
+class SANMVadEncoderExport(torch.nn.Module):
+ def __init__(
+ self,
+ model,
+ max_seq_len=512,
+ feats_dim=560,
+ model_name='encoder',
+ onnx: bool = True,
+ ):
+ super().__init__()
+ self.embed = model.embed
+ self.model = model
+ self._output_size = model._output_size
+
+ from funasr.utils.torch_function import MakePadMask
+ from funasr.utils.torch_function import sequence_mask
+
+ if onnx:
+ self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+ else:
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ from funasr.models.sanm.attention import MultiHeadedAttentionSANMExport
+
+ if hasattr(model, 'encoders0'):
+ for i, d in enumerate(self.model.encoders0):
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMwithMask):
+ d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
+ self.model.encoders0[i] = EncoderLayerSANMExport(d)
+
+ for i, d in enumerate(self.model.encoders):
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMwithMask):
+ d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
+ self.model.encoders[i] = EncoderLayerSANMExport(d)
+
+
+ def prepare_mask(self, mask, sub_masks):
+ mask_3d_btd = mask[:, :, None]
+ mask_4d_bhlt = (1 - sub_masks) * -10000.0
+
+ return mask_3d_btd, mask_4d_bhlt
+
+ def forward(self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ vad_masks: torch.Tensor,
+ sub_masks: torch.Tensor,
+ ):
+ speech = speech * self._output_size ** 0.5
+ mask = self.make_pad_mask(speech_lengths)
+ vad_masks = self.prepare_mask(mask, vad_masks)
+ mask = self.prepare_mask(mask, sub_masks)
+
+ if self.embed is None:
+ xs_pad = speech
+ else:
+ xs_pad = self.embed(speech)
+
+ encoder_outs = self.model.encoders0(xs_pad, mask)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+ # encoder_outs = self.model.encoders(xs_pad, mask)
+ for layer_idx, encoder_layer in enumerate(self.model.encoders):
+ if layer_idx == len(self.model.encoders) - 1:
+ mask = vad_masks
+ encoder_outs = encoder_layer(xs_pad, mask)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+ xs_pad = self.model.after_norm(xs_pad)
+
+ return xs_pad, speech_lengths
+
+ def get_output_size(self):
+ return self.model.encoders[0].size
\ No newline at end of file
diff --git a/funasr/models/ct_transformer_streaming/model.py b/funasr/models/ct_transformer_streaming/model.py
index 217767a..a9b2efb 100644
--- a/funasr/models/ct_transformer_streaming/model.py
+++ b/funasr/models/ct_transformer_streaming/model.py
@@ -173,3 +173,68 @@
return results, meta_data
+ def export(
+ self,
+ **kwargs,
+ ):
+
+ is_onnx = kwargs.get("type", "onnx") == "onnx"
+ encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+ self.encoder = encoder_class(self.encoder, onnx=is_onnx)
+
+ self.forward = self._export_forward
+
+ return self
+
+ def _export_forward(self, inputs: torch.Tensor,
+ text_lengths: torch.Tensor,
+ vad_indexes: torch.Tensor,
+ sub_masks: torch.Tensor,
+ ):
+ """Compute loss value from buffer sequences.
+
+ Args:
+ input (torch.Tensor): Input ids. (batch, len)
+ hidden (torch.Tensor): Target ids. (batch, len)
+
+ """
+ x = self.embed(inputs)
+ # mask = self._target_mask(input)
+ h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
+ y = self.decoder(h)
+ return y
+
+ def export_dummy_inputs(self):
+ length = 120
+ text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
+ text_lengths = torch.tensor([length], dtype=torch.int32)
+ vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
+ sub_masks = torch.ones(length, length, dtype=torch.float32)
+ sub_masks = torch.tril(sub_masks).type(torch.float32)
+ return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
+
+ def export_input_names(self):
+ return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
+
+ def export_output_names(self):
+ return ['logits']
+
+ def export_dynamic_axes(self):
+ return {
+ 'inputs': {
+ 1: 'feats_length'
+ },
+ 'vad_masks': {
+ 2: 'feats_length1',
+ 3: 'feats_length2'
+ },
+ 'sub_masks': {
+ 2: 'feats_length1',
+ 3: 'feats_length2'
+ },
+ 'logits': {
+ 1: 'logits_length'
+ },
+ }
+ def export_name(self):
+ return "model.onnx"
diff --git a/funasr/models/fsmn_vad_streaming/encoder.py b/funasr/models/fsmn_vad_streaming/encoder.py
index e7c0e8b..bc51a6f 100755
--- a/funasr/models/fsmn_vad_streaming/encoder.py
+++ b/funasr/models/fsmn_vad_streaming/encoder.py
@@ -194,7 +194,7 @@
output_affine_dim: int,
output_dim: int
):
- super(FSMN, self).__init__()
+ super().__init__()
self.input_dim = input_dim
self.input_affine_dim = input_affine_dim
@@ -213,12 +213,6 @@
self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
self.softmax = nn.Softmax(dim=-1)
- # export onnx or torchscripts
- if "EXPORTING_MODEL" in os.environ and os.environ['EXPORTING_MODEL'] == 'TRUE':
- for i, d in enumerate(self.fsmn):
- if isinstance(d, BasicBlock):
- self.fsmn[i] = BasicBlock_export(d)
-
def fuse_modules(self):
pass
@@ -244,10 +238,49 @@
return x7
- def export_forward(
- self,
- input: torch.Tensor,
- *args,
+
+@tables.register("encoder_classes", "FSMNExport")
+class FSMNExport(nn.Module):
+ def __init__(
+ self, model, **kwargs,
+ ):
+ super().__init__()
+
+ # self.input_dim = input_dim
+ # self.input_affine_dim = input_affine_dim
+ # self.fsmn_layers = fsmn_layers
+ # self.linear_dim = linear_dim
+ # self.proj_dim = proj_dim
+ # self.output_affine_dim = output_affine_dim
+ # self.output_dim = output_dim
+ #
+ # self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
+ # self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
+ # self.relu = RectifiedLinear(linear_dim, linear_dim)
+ # self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
+ # range(fsmn_layers)])
+ # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
+ # self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
+ # self.softmax = nn.Softmax(dim=-1)
+ self.in_linear1 = model.in_linear1
+ self.in_linear2 = model.in_linear2
+ self.relu = model.relu
+ # self.fsmn = model.fsmn
+ self.out_linear1 = model.out_linear1
+ self.out_linear2 = model.out_linear2
+ self.softmax = model.softmax
+ self.fsmn = model.fsmn
+ for i, d in enumerate(model.fsmn):
+ if isinstance(d, BasicBlock):
+ self.fsmn[i] = BasicBlock_export(d)
+
+ def fuse_modules(self):
+ pass
+
+ def forward(
+ self,
+ input: torch.Tensor,
+ *args,
):
"""
Args:
@@ -255,7 +288,7 @@
in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
{'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
"""
-
+
x = self.in_linear1(input)
x = self.in_linear2(x)
x = self.relu(x)
@@ -268,9 +301,10 @@
x = self.out_linear1(x)
x = self.out_linear2(x)
x = self.softmax(x)
-
+
return x, out_caches
+
'''
one deep fsmn layer
dimproj: projection dimension, input and output dimension of memory blocks
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 9649ee8..c3063b0 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -644,12 +644,17 @@
return results, meta_data
def export(self, **kwargs):
+ is_onnx = kwargs.get("type", "onnx") == "onnx"
+ encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+ self.encoder = encoder_class(self.encoder, onnx=is_onnx)
self.forward = self._export_forward
return self
def _export_forward(self, feats: torch.Tensor, *args, **kwargs):
- scores, out_caches = self.encoder.export_forward(feats, *args)
+
+ scores, out_caches = self.encoder(feats, *args)
+
return scores, out_caches
def export_dummy_inputs(self, data_in=None, frame=30):
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index ddcfb5a..d538e21 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -376,7 +376,7 @@
return predictor_alignments.detach(), predictor_alignments_length.detach()
@tables.register("predictor_classes", "CifPredictorV2Export")
-class CifPredictorV2(torch.nn.Module):
+class CifPredictorV2Export(torch.nn.Module):
def __init__(self, model, **kwargs):
super().__init__()
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index ce018f4..572a34a 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -635,8 +635,9 @@
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
- from funasr.models.sanm.multihead_att import MultiHeadedAttentionSANMDecoderExport
- from funasr.models.sanm.multihead_att import MultiHeadedAttentionCrossAttExport
+ from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
+ from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
+
for i, d in enumerate(self.model.decoders):
if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 8739ed5..586d72d 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -554,21 +554,23 @@
max_seq_len=512,
**kwargs,
):
- onnx = kwargs.get("onnx", True)
+
+ is_onnx = kwargs.get("type", "onnx") == "onnx"
encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
- self.encoder = encoder_class(self.encoder, onnx=onnx)
+ self.encoder = encoder_class(self.encoder, onnx=is_onnx)
predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
- self.predictor = predictor_class(self.predictor, onnx=onnx)
+ self.predictor = predictor_class(self.predictor, onnx=is_onnx)
decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
- self.decoder = decoder_class(self.decoder, onnx=onnx)
+ self.decoder = decoder_class(self.decoder, onnx=is_onnx)
from funasr.utils.torch_function import MakePadMask
from funasr.utils.torch_function import sequence_mask
- if onnx:
+
+ if is_onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index 09a1f07..c3a2f94 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -17,6 +17,24 @@
from funasr.models.transformer.utils.nets_utils import make_pad_mask
import funasr.models.lora.layers as lora
+
+def preprocess_for_attn(x, mask, cache, pad_fn, kernel_size):
+ x = x * mask
+ x = x.transpose(1, 2)
+ if cache is None:
+ x = pad_fn(x)
+ else:
+ x = torch.cat((cache, x), dim=2)
+ cache = x[:, :, -(kernel_size-1):]
+ return x, cache
+
+
+torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
+if torch_version >= (1, 8):
+ import torch.fx
+ torch.fx.wrap('preprocess_for_attn')
+
+
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
@@ -362,6 +380,65 @@
return self.linear_out(context_layer) # (batch, time1, d_model)
+class MultiHeadedAttentionSANMExport(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.d_k = model.d_k
+ self.h = model.h
+ self.linear_out = model.linear_out
+ self.linear_q_k_v = model.linear_q_k_v
+ self.fsmn_block = model.fsmn_block
+ self.pad_fn = model.pad_fn
+
+ self.attn = None
+ self.all_head_size = self.h * self.d_k
+
+ def forward(self, x, mask):
+ mask_3d_btd, mask_4d_bhlt = mask
+ q_h, k_h, v_h, v = self.forward_qkv(x)
+ fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
+ q_h = q_h * self.d_k**(-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
+ return att_outs + fsmn_memory
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.h, self.d_k)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward_qkv(self, x):
+ q_k_v = self.linear_q_k_v(x)
+ q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+ q_h = self.transpose_for_scores(q)
+ k_h = self.transpose_for_scores(k)
+ v_h = self.transpose_for_scores(v)
+ return q_h, k_h, v_h, v
+
+ def forward_fsmn(self, inputs, mask):
+ # b, t, d = inputs.size()
+ # mask = torch.reshape(mask, (b, -1, 1))
+ inputs = inputs * mask
+ x = inputs.transpose(1, 2)
+ x = self.pad_fn(x)
+ x = self.fsmn_block(x)
+ x = x.transpose(1, 2)
+ x = x + inputs
+ x = x * mask
+ return x
+
+ def forward_attention(self, value, scores, mask):
+ scores = scores + mask
+
+ self.attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+ return self.linear_out(context_layer) # (batch, time1, d_model)
+
+
class MultiHeadedAttentionSANMDecoder(nn.Module):
"""Multi-Head Attention layer.
@@ -375,7 +452,7 @@
def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
"""Construct an MultiHeadedAttention object."""
- super(MultiHeadedAttentionSANMDecoder, self).__init__()
+ super().__init__()
self.dropout = nn.Dropout(p=dropout_rate)
@@ -440,6 +517,24 @@
x = x * mask
return x, cache
+class MultiHeadedAttentionSANMDecoderExport(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.fsmn_block = model.fsmn_block
+ self.pad_fn = model.pad_fn
+ self.kernel_size = model.kernel_size
+ self.attn = None
+
+ def forward(self, inputs, mask, cache=None):
+ x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn, self.kernel_size)
+ x = self.fsmn_block(x)
+ x = x.transpose(1, 2)
+
+ x = x + inputs
+ x = x * mask
+ return x, cache
+
+
class MultiHeadedAttentionCrossAtt(nn.Module):
"""Multi-Head Attention layer.
@@ -452,7 +547,7 @@
def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
"""Construct an MultiHeadedAttention object."""
- super(MultiHeadedAttentionCrossAtt, self).__init__()
+ super().__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
@@ -591,6 +686,48 @@
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
return self.forward_attention(v_h, scores, None), cache
+class MultiHeadedAttentionCrossAttExport(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.d_k = model.d_k
+ self.h = model.h
+ self.linear_q = model.linear_q
+ self.linear_k_v = model.linear_k_v
+ self.linear_out = model.linear_out
+ self.attn = None
+ self.all_head_size = self.h * self.d_k
+
+ def forward(self, x, memory, memory_mask):
+ q, k, v = self.forward_qkv(x, memory)
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+ return self.forward_attention(v, scores, memory_mask)
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.h, self.d_k)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward_qkv(self, x, memory):
+ q = self.linear_q(x)
+
+ k_v = self.linear_k_v(memory)
+ k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
+ q = self.transpose_for_scores(q)
+ k = self.transpose_for_scores(k)
+ v = self.transpose_for_scores(v)
+ return q, k, v
+
+ def forward_attention(self, value, scores, mask):
+ scores = scores + mask
+
+ self.attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+ return self.linear_out(context_layer) # (batch, time1, d_model)
+
class MultiHeadSelfAttention(nn.Module):
"""Multi-Head Attention layer.
diff --git a/funasr/models/sanm/attention_export.py b/funasr/models/sanm/attention_export.py
new file mode 100644
index 0000000..435dd1e
--- /dev/null
+++ b/funasr/models/sanm/attention_export.py
@@ -0,0 +1,114 @@
+import os
+import math
+
+import torch
+import torch.nn as nn
+
+
+
+
+
+
+
+
+
+
+class OnnxMultiHeadedAttention(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.d_k = model.d_k
+ self.h = model.h
+ self.linear_q = model.linear_q
+ self.linear_k = model.linear_k
+ self.linear_v = model.linear_v
+ self.linear_out = model.linear_out
+ self.attn = None
+ self.all_head_size = self.h * self.d_k
+
+ def forward(self, query, key, value, mask):
+ q, k, v = self.forward_qkv(query, key, value)
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+ return self.forward_attention(v, scores, mask)
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.h, self.d_k)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward_qkv(self, query, key, value):
+ q = self.linear_q(query)
+ k = self.linear_k(key)
+ v = self.linear_v(value)
+ q = self.transpose_for_scores(q)
+ k = self.transpose_for_scores(k)
+ v = self.transpose_for_scores(v)
+ return q, k, v
+
+ def forward_attention(self, value, scores, mask):
+ scores = scores + mask
+
+ self.attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+ return self.linear_out(context_layer) # (batch, time1, d_model)
+
+
+class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention):
+ def __init__(self, model):
+ super().__init__(model)
+ self.linear_pos = model.linear_pos
+ self.pos_bias_u = model.pos_bias_u
+ self.pos_bias_v = model.pos_bias_v
+
+ def forward(self, query, key, value, pos_emb, mask):
+ q, k, v = self.forward_qkv(query, key, value)
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
+
+ p = self.transpose_for_scores(self.linear_pos(pos_emb)) # (batch, head, time1, d_k)
+
+ # (batch, head, time1, d_k)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ # (batch, head, time1, d_k)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ # compute matrix b and matrix d
+ # (batch, head, time1, time1)
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ matrix_bd = self.rel_shift(matrix_bd)
+
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
+ self.d_k
+ ) # (batch, head, time1, time2)
+
+ return self.forward_attention(v, scores, mask)
+
+ def rel_shift(self, x):
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+ x = x_padded[:, :, 1:].view_as(x)[
+ :, :, :, : x.size(-1) // 2 + 1
+ ] # only keep the positions from 0 to time2
+ return x
+
+ def forward_attention(self, value, scores, mask):
+ scores = scores + mask
+
+ self.attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+ return self.linear_out(context_layer) # (batch, time1, d_model)
+
diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
index 640be05..f563a9b 100644
--- a/funasr/utils/export_utils.py
+++ b/funasr/utils/export_utils.py
@@ -3,7 +3,6 @@
def export_onnx(model,
data_in=None,
- type: str = "onnx",
quantize: bool = False,
fallback_num: int = 5,
calib_num: int = 100,
@@ -19,7 +18,6 @@
m.eval()
_onnx(m,
data_in=data_in,
- type=type,
quantize=quantize,
fallback_num=fallback_num,
calib_num=calib_num,
diff --git a/setup.py b/setup.py
index e3d1c2e..b41148c 100644
--- a/setup.py
+++ b/setup.py
@@ -140,5 +140,6 @@
],
entry_points={"console_scripts": [
"funasr = funasr.bin.inference:main_hydra",
+ "funasr-export = funasr.bin.export:main_hydra",
]},
)
--
Gitblit v1.9.1