游雁
2024-02-20 d79287c37e4e7ae2694a992cbbfb03a5ca4f7670
funasr/models/transformer/attention.py
@@ -312,8 +312,221 @@
        return self.forward_attention(v, scores, mask)
class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
    """RelPositionMultiHeadedAttention definition.
    Args:
        num_heads: Number of attention heads.
        embed_size: Embedding size.
        dropout_rate: Dropout rate.
    """
    def __init__(
        self,
        num_heads: int,
        embed_size: int,
        dropout_rate: float = 0.0,
        simplified_attention_score: bool = False,
    ) -> None:
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        self.d_k = embed_size // num_heads
        self.num_heads = num_heads
        assert self.d_k * num_heads == embed_size, (
            "embed_size (%d) must be divisible by num_heads (%d)",
            (embed_size, num_heads),
        )
        self.linear_q = torch.nn.Linear(embed_size, embed_size)
        self.linear_k = torch.nn.Linear(embed_size, embed_size)
        self.linear_v = torch.nn.Linear(embed_size, embed_size)
        self.linear_out = torch.nn.Linear(embed_size, embed_size)
        if simplified_attention_score:
            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
            self.compute_att_score = self.compute_simplified_attention_score
        else:
            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)
            self.compute_att_score = self.compute_attention_score
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.attn = None
    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute relative positional encoding.
        Args:
            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
            left_context: Number of frames in left context.
        Returns:
            x: Output sequence. (B, H, T_1, T_2)
        """
        batch_size, n_heads, time1, n = x.shape
        time2 = time1 + left_context
        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
        return x.as_strided(
            (batch_size, n_heads, time1, time2),
            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
            storage_offset=(n_stride * (time1 - 1)),
        )
    def compute_simplified_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Simplified attention score computation.
        Reference: https://github.com/k2-fsa/icefall/pull/458
        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.
        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        pos_enc = self.linear_pos(pos_enc)
        matrix_ac = torch.matmul(query, key.transpose(2, 3))
        matrix_bd = self.rel_shift(
            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
            left_context=left_context,
        )
        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
    def compute_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Attention score computation.
        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.
        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
        query = query.transpose(1, 2)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.
        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            v: Value tensor. (B, T_2, size)
        Returns:
            q: Transformed query tensor. (B, H, T_1, d_k)
            k: Transformed key tensor. (B, H, T_2, d_k)
            v: Transformed value tensor. (B, H, T_2, d_k)
        """
        n_batch = query.size(0)
        q = (
            self.linear_q(query)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        k = (
            self.linear_k(key)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        v = (
            self.linear_v(value)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        return q, k, v
    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention context vector.
        Args:
            value: Transformed value. (B, H, T_2, d_k)
            scores: Attention score. (B, H, T_1, T_2)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)
        Returns:
           attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
        """
        batch_size = scores.size(0)
        mask = mask.unsqueeze(1).unsqueeze(2)
        if chunk_mask is not None:
            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
        scores = scores.masked_fill(mask, float("-inf"))
        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        attn_output = self.dropout(self.attn)
        attn_output = torch.matmul(attn_output, value)
        attn_output = self.linear_out(
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_k)
        )
        return attn_output
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Compute scaled dot product attention with rel. positional encoding.
        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)
            left_context: Number of frames in left context.
        Returns:
            : Output tensor. (B, T_1, H * d_k)
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)