From fc547e14e818772811c3dccd9bb09e45e35df168 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 25 九月 2024 15:26:14 +0800
Subject: [PATCH] bugfix memory leaky

---
 funasr/models/sanm/multihead_att.py |   16 ++++++++--------
 1 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/funasr/models/sanm/multihead_att.py b/funasr/models/sanm/multihead_att.py
index c7d9796..671d460 100644
--- a/funasr/models/sanm/multihead_att.py
+++ b/funasr/models/sanm/multihead_att.py
@@ -55,8 +55,8 @@
     def forward_attention(self, value, scores, mask):
         scores = scores + mask
 
-        self.attn = torch.softmax(scores, dim=-1)
-        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
+        attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(attn, value)  # (batch, head, time1, d_k)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -134,8 +134,8 @@
     def forward_attention(self, value, scores, mask):
         scores = scores + mask
 
-        self.attn = torch.softmax(scores, dim=-1)
-        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
+        attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(attn, value)  # (batch, head, time1, d_k)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -177,8 +177,8 @@
     def forward_attention(self, value, scores, mask):
         scores = scores + mask
 
-        self.attn = torch.softmax(scores, dim=-1)
-        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
+        attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(attn, value)  # (batch, head, time1, d_k)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -232,8 +232,8 @@
     def forward_attention(self, value, scores, mask):
         scores = scores + mask
 
-        self.attn = torch.softmax(scores, dim=-1)
-        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
+        attn = torch.softmax(scores, dim=-1)
+        context_layer = torch.matmul(attn, value)  # (batch, head, time1, d_k)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)

--
Gitblit v1.9.1