游雁
2024-02-19 94de39dde2e616a01683c518023d0fab72b4e103
funasr/models/sanm/attention.py
@@ -449,7 +449,7 @@
        return q_h, k_h, v_h
    def forward_attention(self, value, scores, mask):
    def forward_attention(self, value, scores, mask, ret_attn=False):
        """Compute attention context vector.
        Args:
@@ -476,16 +476,16 @@
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        if ret_attn:
            return self.linear_out(x), self.attn  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)
    def forward(self, x, memory, memory_mask):
    def forward(self, x, memory, memory_mask, ret_attn=False):
        """Compute scaled dot product attention.
        Args:
@@ -502,7 +502,7 @@
        q_h, k_h, v_h = self.forward_qkv(x, memory)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, memory_mask)
        return self.forward_attention(v_h, scores, memory_mask, ret_attn=ret_attn)
    def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
        """Compute scaled dot product attention.