游雁
2024-03-27 9b4e9cc8a0311e5243d69b73ed073e7ea441982e
funasr/models/sanm/attention.py
@@ -697,10 +697,10 @@
        self.attn = None
        self.all_head_size = self.h * self.d_k
    def forward(self, x, memory, memory_mask):
    def forward(self, x, memory, memory_mask, ret_attn=False):
        q, k, v = self.forward_qkv(x, memory)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, memory_mask)
        return self.forward_attention(v, scores, memory_mask, ret_attn)
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
@@ -717,7 +717,7 @@
        v = self.transpose_for_scores(v)
        return q, k, v
    def forward_attention(self, value, scores, mask):
    def forward_attention(self, value, scores, mask, ret_attn):
        scores = scores + mask
        self.attn = torch.softmax(scores, dim=-1)
@@ -726,6 +726,7 @@
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        if ret_attn: return self.linear_out(context_layer), self.attn
        return self.linear_out(context_layer)  # (batch, time1, d_model)
@@ -831,6 +832,3 @@
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs