From 38de2af5bf9976d2f14f087d9a0d31991daf6783 Mon Sep 17 00:00:00 2001
From: Zhihao Du <neo.dzh@alibaba-inc.com>
Date: Thu, 16 Mar 2023 19:41:34 +0800
Subject: [PATCH] Merge branch 'main' into dev_dzh

---
 funasr/export/models/modules/multihead_att.py |   28 +++++++++++++++++-----------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/funasr/export/models/modules/multihead_att.py b/funasr/export/models/modules/multihead_att.py
index 7d685f5..0a56676 100644
--- a/funasr/export/models/modules/multihead_att.py
+++ b/funasr/export/models/modules/multihead_att.py
@@ -64,6 +64,21 @@
         return self.linear_out(context_layer)  # (batch, time1, d_model)
 
 
+def preprocess_for_attn(x, mask, cache, pad_fn):
+    x = x * mask
+    x = x.transpose(1, 2)
+    if cache is None:
+        x = pad_fn(x)
+    else:
+        x = torch.cat((cache[:, :, 1:], x), dim=2)
+        cache = x
+    return x, cache
+
+
+import torch.fx
+torch.fx.wrap('preprocess_for_attn')
+
+
 class MultiHeadedAttentionSANMDecoder(nn.Module):
     def __init__(self, model):
         super().__init__()
@@ -73,16 +88,7 @@
         self.attn = None
 
     def forward(self, inputs, mask, cache=None):
-        # b, t, d = inputs.size()
-        # mask = torch.reshape(mask, (b, -1, 1))
-        inputs = inputs * mask
-
-        x = inputs.transpose(1, 2)
-        if cache is None:
-            x = self.pad_fn(x)
-        else:
-            x = torch.cat((cache[:, :, 1:], x), dim=2)
-            cache = x
+        x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn)
         x = self.fsmn_block(x)
         x = x.transpose(1, 2)
 
@@ -232,4 +238,4 @@
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(new_context_layer_shape)
         return self.linear_out(context_layer)  # (batch, time1, d_model)
-        
\ No newline at end of file
+        

--
Gitblit v1.9.1