游雁
2024-09-25 fc547e14e818772811c3dccd9bb09e45e35df168
bugfix memory leaky
5个文件已修改
68 ■■■■ 已修改文件
funasr/models/lcbnet/attention.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sanm/multihead_att.py 16 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/model.py 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sond/attention.py 12 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/transformer/attention.py 18 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/lcbnet/attention.py
@@ -78,19 +78,19 @@
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x), self.attn  # (batch, time1, d_model)
        return self.linear_out(x), attn  # (batch, time1, d_model)
    def forward(self, query, key, value, mask):
        """Compute scaled dot product attention.
funasr/models/sanm/multihead_att.py
@@ -55,8 +55,8 @@
    def forward_attention(self, value, scores, mask):
        scores = scores + mask
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(attn, value)  # (batch, head, time1, d_k)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -134,8 +134,8 @@
    def forward_attention(self, value, scores, mask):
        scores = scores + mask
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(attn, value)  # (batch, head, time1, d_k)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -177,8 +177,8 @@
    def forward_attention(self, value, scores, mask):
        scores = scores + mask
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(attn, value)  # (batch, head, time1, d_k)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -232,8 +232,8 @@
    def forward_attention(self, value, scores, mask):
        scores = scores + mask
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(attn, value)  # (batch, head, time1, d_k)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
funasr/models/sense_voice/model.py
@@ -196,13 +196,13 @@
                "inf"
            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@@ -644,7 +644,13 @@
        self.embed = torch.nn.Embedding(
            7 + len(self.lid_dict) + len(self.textnorm_dict), input_size
        )
        self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}
        self.emo_dict = {
            "unk": 25009,
            "happy": 25001,
            "sad": 25002,
            "angry": 25003,
            "neutral": 25004,
        }
        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
funasr/models/sond/attention.py
@@ -84,13 +84,13 @@
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@@ -287,13 +287,13 @@
            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
funasr/models/transformer/attention.py
@@ -87,13 +87,13 @@
                "inf"
            )  # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@@ -154,8 +154,8 @@
    def forward_attention(self, value, scores, mask):
        scores = scores + mask
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(attn, value)  # (batch, head, time1, d_k)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -209,8 +209,8 @@
    def forward_attention(self, value, scores, mask):
        scores = scores + mask
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(attn, value)  # (batch, head, time1, d_k)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -575,9 +575,9 @@
        if chunk_mask is not None:
            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
        scores = scores.masked_fill(mask, float("-inf"))
        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        attn_output = self.dropout(self.attn)
        attn_output = self.dropout(attn)
        attn_output = torch.matmul(attn_output, value)
        attn_output = self.linear_out(