From 23e7ddebccd3b05cf7ef89809bcfe565ad6dfa1f Mon Sep 17 00:00:00 2001
From: majic31 <majic31@163.com>
Date: Tue, 24 Dec 2024 10:00:14 +0800
Subject: [PATCH] Fix the variable name (#2328)
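
Replace the stateful `self.attn` attribute with a local `attn` variable
throughout the SANM attention modules, mask attention scores with -inf
instead of the dtype minimum before the softmax, move `mask` to the device
of `scores` before adding it in forward_attention, and drop the now-unused
`self.attn = None` initialisation.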
---
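Note for reviewers: a minimal standalone sketch of the masking pattern the
patched code relies on (toy shapes and values for illustration only, not
taken from the FunASR sources):

    import torch

    # Toy attention scores: (batch=1, head=1, time1=2, time2=3).
    scores = torch.tensor([[[[0.5, 1.0, 2.0],
                             [1.5, 0.0, 1.0]]]])
    # True marks padded key positions to ignore (here the last key).
    mask = torch.tensor([[[[False, False, True],
                           [False, False, True]]]])

    # Fill masked positions with -inf so softmax assigns them zero weight.
    scores = scores.masked_fill(mask, -float("inf"))
    attn = torch.softmax(scores, dim=-1)
    # Re-zero masked positions afterwards: if every key in a row were
    # masked, softmax over a row of -inf values would produce NaNs.
    attn = attn.masked_fill(mask, 0.0)

    print(attn)  # masked column is exactly 0.0; each row sums to 1

Since `attn` is now a local rather than module state, the attention map can
no longer be read off the module after a forward pass; callers that need it
can use the `ret_attn` path where these methods expose it.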
 funasr/models/sanm/attention.py | 59 +++++++++++++++++++++++++++++++++--------------------------
 1 file changed, 33 insertions(+), 26 deletions(-)
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index da8850f..47d60cb 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -100,15 +100,17 @@
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
- min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
+ min_value = -float(
+ "inf"
+ ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
- p_attn = self.dropout(self.attn)
+ p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@@ -189,7 +191,6 @@
else:
self.linear_out = nn.Linear(n_feat, n_feat)
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
- self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
self.fsmn_block = nn.Conv1d(
@@ -269,15 +270,17 @@
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
- min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
+ min_value = -float(
+ "inf"
+ ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
- p_attn = self.dropout(self.attn)
+ p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@@ -396,8 +399,8 @@
def forward_attention(self, value, scores, mask):
scores = scores + mask
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+ attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -455,8 +458,8 @@
def forward_attention(self, value, scores, mask):
scores = scores + mask
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+ attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -673,22 +676,24 @@
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
- min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
+ min_value = -float(
+ "inf"
+ ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
# logging.info(
# "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
- p_attn = self.dropout(self.attn)
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)
if ret_attn:
- return self.linear_out(x), self.attn # (batch, time1, d_model)
+ return self.linear_out(x), attn # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, x, memory, memory_mask, ret_attn=False):
@@ -774,16 +779,16 @@
return q, k, v
def forward_attention(self, value, scores, mask, ret_attn):
- scores = scores + mask
+ scores = scores + mask.to(scores.device)
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+ attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
if ret_attn:
- return self.linear_out(context_layer), self.attn
+ return self.linear_out(context_layer), attn
return self.linear_out(context_layer) # (batch, time1, d_model)
@@ -858,15 +863,17 @@
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
- min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
+ min_value = -float(
+ "inf"
+ ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
- p_attn = self.dropout(self.attn)
+ p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
--
Gitblit v1.9.1