From 4137f5cf26e7c4b40853959cd2574edfde03aa60 Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: Fri, 07 Apr 2023 21:03:34 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR into dev_dzh
---
funasr/export/models/modules/multihead_att.py | 30 +++++++++++++++++++-----------
 1 file changed, 19 insertions(+), 11 deletions(-)
diff --git a/funasr/export/models/modules/multihead_att.py b/funasr/export/models/modules/multihead_att.py
index 7d685f5..6fce851 100644
--- a/funasr/export/models/modules/multihead_att.py
+++ b/funasr/export/models/modules/multihead_att.py
@@ -64,6 +64,23 @@
return self.linear_out(context_layer) # (batch, time1, d_model)
+def preprocess_for_attn(x, mask, cache, pad_fn):
+ x = x * mask
+ x = x.transpose(1, 2)
+ if cache is None:
+ x = pad_fn(x)
+ else:
+ x = torch.cat((cache[:, :, 1:], x), dim=2)
+ cache = x
+ return x, cache
+
+
+torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
+if torch_version >= (1, 8):
+ import torch.fx
+ torch.fx.wrap('preprocess_for_attn')
+
+
class MultiHeadedAttentionSANMDecoder(nn.Module):
def __init__(self, model):
super().__init__()
@@ -73,16 +90,7 @@
self.attn = None
def forward(self, inputs, mask, cache=None):
- # b, t, d = inputs.size()
- # mask = torch.reshape(mask, (b, -1, 1))
- inputs = inputs * mask
-
- x = inputs.transpose(1, 2)
- if cache is None:
- x = self.pad_fn(x)
- else:
- x = torch.cat((cache[:, :, 1:], x), dim=2)
- cache = x
+ x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
@@ -232,4 +240,4 @@
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return self.linear_out(context_layer) # (batch, time1, d_model)
-
\ No newline at end of file
+
--
Gitblit v1.9.1