liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/sanm/attention.py
@@ -780,7 +780,7 @@
        return q, k, v
    def forward_attention(self, value, scores, mask, ret_attn):
        scores = scores + mask
        scores = scores + mask.to(scores.device)
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)