From 9b4e9cc8a0311e5243d69b73ed073e7ea441982e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 27 Mar 2024 16:05:29 +0800
Subject: [PATCH] sanm attention: optionally return attention weights via ret_attn

---
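This patch threads an optional ret_attn flag through the SANM cross-attention
forward pass: when set, forward() returns the post-softmax attention map along
with the attention output instead of the output alone. Existing call sites are
unaffected because ret_attn defaults to False. A minimal usage sketch of the new
signature, assuming attn_layer is an already-constructed instance of the patched
attention module (the instance name, tensor shapes, and the all-zero additive
mask below are illustrative assumptions, not part of the patch):

    import torch

    batch, time1, time2, d_model = 2, 5, 7, 256   # illustrative sizes
    x = torch.randn(batch, time1, d_model)        # query-side input
    memory = torch.randn(batch, time2, d_model)   # encoder memory
    # the patched code adds the mask to the scores, so an all-zero
    # additive mask keeps every position visible
    memory_mask = torch.zeros(batch, 1, 1, time2)

    out = attn_layer(x, memory, memory_mask)      # unchanged default path
    out, attn = attn_layer(x, memory, memory_mask, ret_attn=True)
    # attn is the stored self.attn, shape (batch, heads, time1, time2)
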
funasr/models/sanm/attention.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index 5f91268..1768bbd 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -697,10 +697,10 @@
         self.attn = None
         self.all_head_size = self.h * self.d_k
 
-    def forward(self, x, memory, memory_mask):
+    def forward(self, x, memory, memory_mask, ret_attn=False):
         q, k, v = self.forward_qkv(x, memory)
         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, memory_mask)
+        return self.forward_attention(v, scores, memory_mask, ret_attn)
 
     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.h, self.d_k)
@@ -717,7 +717,7 @@
         v = self.transpose_for_scores(v)
         return q, k, v
 
-    def forward_attention(self, value, scores, mask):
+    def forward_attention(self, value, scores, mask, ret_attn):
         scores = scores + mask
 
         self.attn = torch.softmax(scores, dim=-1)
@@ -726,6 +726,7 @@
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(new_context_layer_shape)
+        if ret_attn: return self.linear_out(context_layer), self.attn
         return self.linear_out(context_layer)  # (batch, time1, d_model)
 
--
Gitblit v1.9.1