From 1af68ba6ffc21d4dc3bd5f01cda656def97e361c Mon Sep 17 00:00:00 2001
From: Nixon <2465004358@qq.com>
Date: Sat, 14 Sep 2024 10:13:23 +0800
Subject: [PATCH] fix bugs: 1) fix CUDA OOM; 2) fix choosing a window size 400 that is [2, 0] (#2075)
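
The forward paths in funasr/models/sanm/attention.py cached the softmax
attention weights on the module as self.attn. That keeps a
(batch, head, time1, time2) tensor, and under grad mode the graph that
produced it, alive on the module after every call, which can lead to CUDA
OOM on long inputs. This patch replaces the cached attribute with a local
variable, so the weights are freed as soon as each forward pass returns;
callers that pass ret_attn=True still receive the weights as a return
value.

A minimal sketch of the pattern being removed (illustrative only; the
class names below are hypothetical, not the FunASR classes):

    import torch

    class CachingAttention(torch.nn.Module):
        """Anti-pattern: stores attention weights on the module."""

        def forward(self, scores, value):
            # Storing on self keeps the (batch, head, time1, time2)
            # tensor (and, under grad mode, its autograd graph) alive
            # until the next call overwrites it.
            self.attn = torch.softmax(scores, dim=-1)
            return torch.matmul(self.attn, value)

    class LocalAttention(torch.nn.Module):
        """Fixed: the weights live in a local and are freed on return."""

        def forward(self, scores, value):
            attn = torch.softmax(scores, dim=-1)
            return torch.matmul(attn, value)
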
---
 funasr/models/sanm/attention.py | 41 ++++++++++++++++++++---------------------
 1 file changed, 20 insertions(+), 21 deletions(-)
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index c7e8a8e..47d60cb 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -104,13 +104,13 @@
"inf"
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
- p_attn = self.dropout(self.attn)
+ p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@@ -191,7 +191,6 @@
else:
self.linear_out = nn.Linear(n_feat, n_feat)
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
- self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
self.fsmn_block = nn.Conv1d(
@@ -275,13 +275,13 @@
"inf"
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
- p_attn = self.dropout(self.attn)
+ p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@@ -400,8 +400,8 @@
def forward_attention(self, value, scores, mask):
scores = scores + mask
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+ attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -459,8 +459,8 @@
def forward_attention(self, value, scores, mask):
scores = scores + mask
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+ attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -683,18 +683,18 @@
# logging.info(
# "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
- p_attn = self.dropout(self.attn)
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)
if ret_attn:
- return self.linear_out(x), self.attn # (batch, time1, d_model)
+ return self.linear_out(x), attn # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, x, memory, memory_mask, ret_attn=False):
@@ -782,14 +782,14 @@
def forward_attention(self, value, scores, mask, ret_attn):
scores = scores + mask.to(scores.device)
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+ attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
if ret_attn:
- return self.linear_out(context_layer), self.attn
+ return self.linear_out(context_layer), attn
return self.linear_out(context_layer) # (batch, time1, d_model)
@@ -868,13 +868,13 @@
"inf"
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
- p_attn = self.dropout(self.attn)
+ p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
--
Gitblit v1.9.1