From 9b4e9cc8a0311e5243d69b73ed073e7ea441982e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 27 Mar 2024 16:05:29 +0800
Subject: [PATCH] train update

---
 funasr/models/sanm/attention.py |    7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index 5f91268..1768bbd 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -697,10 +697,10 @@
         self.attn = None
         self.all_head_size = self.h * self.d_k
 
-    def forward(self, x, memory, memory_mask):
+    def forward(self, x, memory, memory_mask, ret_attn=False):
         q, k, v = self.forward_qkv(x, memory)
         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, memory_mask)
+        return self.forward_attention(v, scores, memory_mask, ret_attn)
 
     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.h, self.d_k)
@@ -717,7 +717,7 @@
         v = self.transpose_for_scores(v)
         return q, k, v
 
-    def forward_attention(self, value, scores, mask):
+    def forward_attention(self, value, scores, mask, ret_attn):
         scores = scores + mask
 
         self.attn = torch.softmax(scores, dim=-1)
@@ -726,6 +726,7 @@
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(new_context_layer_shape)
+        if ret_attn: return self.linear_out(context_layer), self.attn
         return self.linear_out(context_layer)  # (batch, time1, d_model)
 
 

--
Gitblit v1.9.1