From 38de2af5bf9976d2f14f087d9a0d31991daf6783 Mon Sep 17 00:00:00 2001
From: Zhihao Du <neo.dzh@alibaba-inc.com>
Date: Thu, 16 Mar 2023 19:41:34 +0800
Subject: [PATCH] Merge branch 'main' into dev_dzh
---
funasr/export/models/modules/multihead_att.py | 28 +++++++++++++++++-----------
1 file changed, 17 insertions(+), 11 deletions(-)
diff --git a/funasr/export/models/modules/multihead_att.py b/funasr/export/models/modules/multihead_att.py
index 7d685f5..0a56676 100644
--- a/funasr/export/models/modules/multihead_att.py
+++ b/funasr/export/models/modules/multihead_att.py
@@ -64,6 +64,21 @@
return self.linear_out(context_layer) # (batch, time1, d_model)
+def preprocess_for_attn(x, mask, cache, pad_fn):
+ x = x * mask
+ x = x.transpose(1, 2)
+ if cache is None:
+ x = pad_fn(x)
+ else:
+ x = torch.cat((cache[:, :, 1:], x), dim=2)
+ cache = x
+ return x, cache
+
+
+import torch.fx
+torch.fx.wrap('preprocess_for_attn')
+
+
class MultiHeadedAttentionSANMDecoder(nn.Module):
def __init__(self, model):
super().__init__()
@@ -73,16 +88,7 @@
self.attn = None
def forward(self, inputs, mask, cache=None):
- # b, t, d = inputs.size()
- # mask = torch.reshape(mask, (b, -1, 1))
- inputs = inputs * mask
-
- x = inputs.transpose(1, 2)
- if cache is None:
- x = self.pad_fn(x)
- else:
- x = torch.cat((cache[:, :, 1:], x), dim=2)
- cache = x
+ x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
@@ -232,4 +238,4 @@
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return self.linear_out(context_layer) # (batch, time1, d_model)
-
\ No newline at end of file
+
--
Gitblit v1.9.1