From 9d48230c4f8f25bf88c5d6105f97370a36c9cf43 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 11 Mar 2024 10:48:50 +0800
Subject: [PATCH] export onnx (#1457)
---
funasr/models/sanm/attention.py | 141 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 139 insertions(+), 2 deletions(-)
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index 09a1f07..c3a2f94 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -17,6 +17,24 @@
from funasr.models.transformer.utils.nets_utils import make_pad_mask
import funasr.models.lora.layers as lora
+
+def preprocess_for_attn(x, mask, cache, pad_fn, kernel_size):
+    x = x * mask
+    x = x.transpose(1, 2)
+    if cache is None:
+        x = pad_fn(x)
+    else:
+        x = torch.cat((cache, x), dim=2)
+    cache = x[:, :, -(kernel_size-1):]
+    return x, cache
+
+
+torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
+if torch_version >= (1, 8):
+    import torch.fx
+    torch.fx.wrap('preprocess_for_attn')
+
+
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.
@@ -362,6 +380,65 @@
        return self.linear_out(context_layer) # (batch, time1, d_model)
+class MultiHeadedAttentionSANMExport(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.d_k = model.d_k
+        self.h = model.h
+        self.linear_out = model.linear_out
+        self.linear_q_k_v = model.linear_q_k_v
+        self.fsmn_block = model.fsmn_block
+        self.pad_fn = model.pad_fn
+
+        self.attn = None
+        self.all_head_size = self.h * self.d_k
+
+    def forward(self, x, mask):
+        mask_3d_btd, mask_4d_bhlt = mask
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
+        q_h = q_h * self.d_k**(-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
+        return att_outs + fsmn_memory
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward_qkv(self, x):
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        q_h = self.transpose_for_scores(q)
+        k_h = self.transpose_for_scores(k)
+        v_h = self.transpose_for_scores(v)
+        return q_h, k_h, v_h, v
+
+    def forward_fsmn(self, inputs, mask):
+        # b, t, d = inputs.size()
+        # mask = torch.reshape(mask, (b, -1, 1))
+        inputs = inputs * mask
+        x = inputs.transpose(1, 2)
+        x = self.pad_fn(x)
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+        x = x + inputs
+        x = x * mask
+        return x
+
+    def forward_attention(self, value, scores, mask):
+        scores = scores + mask
+
+        self.attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+        return self.linear_out(context_layer) # (batch, time1, d_model)
+
+
class MultiHeadedAttentionSANMDecoder(nn.Module):
    """Multi-Head Attention layer.
@@ -375,7 +452,7 @@
    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionSANMDecoder, self).__init__()
+        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
@@ -440,6 +517,24 @@
        x = x * mask
        return x, cache
+class MultiHeadedAttentionSANMDecoderExport(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.fsmn_block = model.fsmn_block
+        self.pad_fn = model.pad_fn
+        self.kernel_size = model.kernel_size
+        self.attn = None
+
+    def forward(self, inputs, mask, cache=None):
+        x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn, self.kernel_size)
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+
+        x = x + inputs
+        x = x * mask
+        return x, cache
+
+
class MultiHeadedAttentionCrossAtt(nn.Module):
    """Multi-Head Attention layer.
@@ -452,7 +547,7 @@
    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionCrossAtt, self).__init__()
+        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
@@ -591,6 +686,48 @@
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, None), cache
+class MultiHeadedAttentionCrossAttExport(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.d_k = model.d_k
+        self.h = model.h
+        self.linear_q = model.linear_q
+        self.linear_k_v = model.linear_k_v
+        self.linear_out = model.linear_out
+        self.attn = None
+        self.all_head_size = self.h * self.d_k
+
+    def forward(self, x, memory, memory_mask):
+        q, k, v = self.forward_qkv(x, memory)
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, memory_mask)
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward_qkv(self, x, memory):
+        q = self.linear_q(x)
+
+        k_v = self.linear_k_v(memory)
+        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
+        q = self.transpose_for_scores(q)
+        k = self.transpose_for_scores(k)
+        v = self.transpose_for_scores(v)
+        return q, k, v
+
+    def forward_attention(self, value, scores, mask):
+        scores = scores + mask
+
+        self.attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+        return self.linear_out(context_layer) # (batch, time1, d_model)
+
class MultiHeadSelfAttention(nn.Module):
    """Multi-Head Attention layer.
--
Gitblit v1.9.1