From fc547e14e818772811c3dccd9bb09e45e35df168 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 25 九月 2024 15:26:14 +0800
Subject: [PATCH] bugfix memory leaky
---
funasr/models/sense_voice/model.py | 16 +++++--
funasr/models/sanm/multihead_att.py | 16 ++++----
funasr/models/sond/attention.py | 12 +++---
funasr/models/lcbnet/attention.py | 8 ++--
funasr/models/transformer/attention.py | 18 ++++----
5 files changed, 38 insertions(+), 32 deletions(-)
diff --git a/funasr/models/lcbnet/attention.py b/funasr/models/lcbnet/attention.py
index 05a5041..83753ed 100644
--- a/funasr/models/lcbnet/attention.py
+++ b/funasr/models/lcbnet/attention.py
@@ -78,19 +78,19 @@
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
- p_attn = self.dropout(self.attn)
+ p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)
- return self.linear_out(x), self.attn # (batch, time1, d_model)
+ return self.linear_out(x), attn # (batch, time1, d_model)
def forward(self, query, key, value, mask):
"""Compute scaled dot product attention.
diff --git a/funasr/models/sanm/multihead_att.py b/funasr/models/sanm/multihead_att.py
index c7d9796..671d460 100644
--- a/funasr/models/sanm/multihead_att.py
+++ b/funasr/models/sanm/multihead_att.py
@@ -55,8 +55,8 @@
def forward_attention(self, value, scores, mask):
scores = scores + mask
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+ attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -134,8 +134,8 @@
def forward_attention(self, value, scores, mask):
scores = scores + mask
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+ attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -177,8 +177,8 @@
def forward_attention(self, value, scores, mask):
scores = scores + mask
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+ attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -232,8 +232,8 @@
def forward_attention(self, value, scores, mask):
scores = scores + mask
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+ attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 1311987..25e9faf 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -196,13 +196,13 @@
"inf"
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
- p_attn = self.dropout(self.attn)
+ p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@@ -644,7 +644,13 @@
self.embed = torch.nn.Embedding(
7 + len(self.lid_dict) + len(self.textnorm_dict), input_size
)
- self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}
+ self.emo_dict = {
+ "unk": 25009,
+ "happy": 25001,
+ "sad": 25002,
+ "angry": 25003,
+ "neutral": 25004,
+ }
self.criterion_att = LabelSmoothingLoss(
size=self.vocab_size,
@@ -874,7 +880,7 @@
ctc_logits = self.ctc.log_softmax(encoder_out)
if kwargs.get("ban_emo_unk", False):
ctc_logits[:, :, self.emo_dict["unk"]] = -float("inf")
-
+
results = []
b, n, d = encoder_out.size()
if isinstance(key[0], (list, tuple)):
diff --git a/funasr/models/sond/attention.py b/funasr/models/sond/attention.py
index 1af5534..18580b7 100644
--- a/funasr/models/sond/attention.py
+++ b/funasr/models/sond/attention.py
@@ -84,13 +84,13 @@
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
- p_attn = self.dropout(self.attn)
+ p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@@ -287,13 +287,13 @@
min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
- p_attn = self.dropout(self.attn)
+ p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
diff --git a/funasr/models/transformer/attention.py b/funasr/models/transformer/attention.py
index 6e6f754..2844333 100644
--- a/funasr/models/transformer/attention.py
+++ b/funasr/models/transformer/attention.py
@@ -87,13 +87,13 @@
"inf"
) # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
- p_attn = self.dropout(self.attn)
+ p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@@ -154,8 +154,8 @@
def forward_attention(self, value, scores, mask):
scores = scores + mask
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+ attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -209,8 +209,8 @@
def forward_attention(self, value, scores, mask):
scores = scores + mask
- self.attn = torch.softmax(scores, dim=-1)
- context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
+ attn = torch.softmax(scores, dim=-1)
+ context_layer = torch.matmul(attn, value) # (batch, head, time1, d_k)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -575,9 +575,9 @@
if chunk_mask is not None:
mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
scores = scores.masked_fill(mask, float("-inf"))
- self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+ attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
- attn_output = self.dropout(self.attn)
+ attn_output = self.dropout(attn)
attn_output = torch.matmul(attn_output, value)
attn_output = self.linear_out(
--
Gitblit v1.9.1