zhifu gao
2024-06-20 e65b1f701abca03bf3a1b5fbb200392aabd38c22
funasr/models/sanm/attention.py
@@ -780,7 +780,7 @@
        return q, k, v
    def forward_attention(self, value, scores, mask, ret_attn):
-        scores = scores + mask
+        scores = scores + mask.to(scores.device)
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)