hnluo
2023-09-11 9fcb3cc06b4e324f0913d2f61b89becc2baeef1b
funasr/modules/attention.py
@@ -11,7 +11,11 @@
import numpy
import torch
from torch import nn
from typing import Optional, Tuple
import torch.nn.functional as F
from funasr.modules.nets_utils import make_pad_mask
import funasr.modules.lora.layers as lora
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.
@@ -318,7 +322,7 @@
    """
    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
        """Construct an MultiHeadedAttention object."""
        super(MultiHeadedAttentionSANM, self).__init__()
        assert n_feat % n_head == 0
@@ -328,8 +332,19 @@
        # self.linear_q = nn.Linear(n_feat, n_feat)
        # self.linear_k = nn.Linear(n_feat, n_feat)
        # self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        if lora_list is not None:
            if "o" in lora_list:
                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
            else:
                self.linear_out = nn.Linear(n_feat, n_feat)
            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
            if lora_qkv_list == [False, False, False]:
                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
            else:
                self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
        else:
            self.linear_out = nn.Linear(n_feat, n_feat)
            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
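        # LoRA wiring (descriptive note, not from the original source): lora_list picks
        # which projections get low-rank adapters. "o" wraps the output projection in
        # lora.Linear; any of "q"/"k"/"v" switches the fused query/key/value projection
        # to lora.MergedLinear, whose enable_lora flags select which thirds of the
        # (in_feat -> 3 * n_feat) weight are adapted. A hypothetical call such as
        #     MultiHeadedAttentionSANM(4, 512, 512, 0.1, kernel_size=11,
        #                              lora_list=["q", "v"], lora_rank=8)
        # would adapt only the query and value slices and keep k/o as plain nn.Linear.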
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)
@@ -347,15 +362,17 @@
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask
        inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        return x * mask
        if mask is not None:
            x = x * mask
        return x
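    # Editorial note on the change above: forward_fsmn now returns x directly and only
    # re-applies the mask when one is provided, so the method no longer fails when
    # mask is None (presumably the streaming/decoding path, where inputs are unpadded).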
    def forward_qkv(self, x):
        """Transform query, key and value.
@@ -505,7 +522,7 @@
            # print("in fsmn, cache is None, x", x.size())
            x = self.pad_fn(x)
            if not self.training and t <= 1:
            if not self.training:
                cache = x
        else:
            # print("in fsmn, cache is not None, x", x.size())
@@ -513,7 +530,7 @@
            # if t < self.kernel_size:
            #     x = self.pad_fn(x)
            x = torch.cat((cache[:, :, 1:], x), dim=2)
            x = x[:, :, -self.kernel_size:]
            x = x[:, :, -(self.kernel_size+t-1):]
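            # Illustrative walkthrough (not in the original source): after concatenating
            # the cache with the current chunk, keep the last kernel_size - 1 cached
            # frames plus the t new frames. E.g. with kernel_size = 11 and a 5-frame
            # chunk, the slice keeps 15 frames, so the causal FSMN convolution sees the
            # full left context for every new frame; the old "-self.kernel_size:" slice
            # kept only 11 frames regardless of chunk length.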
            # print("in fsmn, cache is not None, x_cat", x.size())
            cache = x
        x = self.fsmn_block(x)
@@ -538,18 +555,32 @@
    """
    def __init__(self, n_head, n_feat, dropout_rate, encoder_output_size=None):
    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
        """Construct an MultiHeadedAttention object."""
        super(MultiHeadedAttentionCrossAtt, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        # self.linear_k = nn.Linear(n_feat, n_feat)
        # self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
        self.linear_out = nn.Linear(n_feat, n_feat)
        if lora_list is not None:
            if "q" in lora_list:
                self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
            else:
                self.linear_q = nn.Linear(n_feat, n_feat)
            lora_kv_list = ["k" in lora_list, "v" in lora_list]
            if lora_kv_list == [False, False]:
                self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
            else:
                self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
                                      r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
            if "o" in lora_list:
                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
            else:
                self.linear_out = nn.Linear(n_feat, n_feat)
        else:
            self.linear_q = nn.Linear(n_feat, n_feat)
            self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
            self.linear_out = nn.Linear(n_feat, n_feat)
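        # Descriptive note (not from the original source): the cross-attention variant
        # splits the adapters differently from the self-attention case. "q" controls a
        # standalone lora.Linear on the query, while "k"/"v" toggle the two halves of
        # the fused k_v projection (input width n_feat, or encoder_output_size when the
        # encoder dimension differs) via MergedLinear's enable_lora flags.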
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)
@@ -739,3 +770,255 @@
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs
class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
    """RelPositionMultiHeadedAttention definition.
    Args:
        num_heads: Number of attention heads.
        embed_size: Embedding size.
        dropout_rate: Dropout rate.
    """
    def __init__(
        self,
        num_heads: int,
        embed_size: int,
        dropout_rate: float = 0.0,
        simplified_attention_score: bool = False,
    ) -> None:
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        self.d_k = embed_size // num_heads
        self.num_heads = num_heads
        assert self.d_k * num_heads == embed_size, (
            "embed_size (%d) must be divisible by num_heads (%d)"
            % (embed_size, num_heads)
        )
        self.linear_q = torch.nn.Linear(embed_size, embed_size)
        self.linear_k = torch.nn.Linear(embed_size, embed_size)
        self.linear_v = torch.nn.Linear(embed_size, embed_size)
        self.linear_out = torch.nn.Linear(embed_size, embed_size)
        if simplified_attention_score:
            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
            self.compute_att_score = self.compute_simplified_attention_score
        else:
            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)
            self.compute_att_score = self.compute_attention_score
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.attn = None
    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute relative positional encoding.
        Args:
            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
            left_context: Number of frames in left context.
        Returns:
            x: Output sequence. (B, H, T_1, T_2)
        """
        batch_size, n_heads, time1, n = x.shape
        time2 = time1 + left_context
        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
        return x.as_strided(
            (batch_size, n_heads, time1, time2),
            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
            storage_offset=(n_stride * (time1 - 1)),
        )
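    # Worked example of the strided shift above (illustrative, left_context = 0): with
    # time1 = 3 the relative axis has n = 2 * 3 - 1 = 5 positions, and the returned view
    # satisfies out[b, h, i, j] = x[b, h, i, j - i + time1 - 1], so row 0 reads relative
    # positions [2, 3, 4], row 1 reads [1, 2, 3] and row 2 reads [0, 1, 2]. No data is
    # copied; only the strides and the storage offset change.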
    def compute_simplified_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Simplified attention score computation.
        Reference: https://github.com/k2-fsa/icefall/pull/458
        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.
        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        pos_enc = self.linear_pos(pos_enc)
        matrix_ac = torch.matmul(query, key.transpose(2, 3))
        matrix_bd = self.rel_shift(
            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
            left_context=left_context,
        )
        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
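    # Shape note for the simplified score (editorial): linear_pos maps embed_size to
    # num_heads, so the positional term is a single learned scalar per head and relative
    # offset; matrix_bd is built as (B, H, T_1, 2 * T_1 - 1) and rel_shift re-indexes it
    # to (B, H, T_1, T_2) before being added to the content term matrix_ac.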
    def compute_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Attention score computation.
        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.
        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
        query = query.transpose(1, 2)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
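    # Editorial note: this is the Transformer-XL style decomposition, i.e.
    #     score = ((q + u) k^T + rel_shift((q + v) p^T)) / sqrt(d_k)
    # with u = pos_bias_u and v = pos_bias_v as learned per-head biases and p the
    # projected relative positional embeddings.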
    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.
        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
        Returns:
            q: Transformed query tensor. (B, H, T_1, d_k)
            k: Transformed key tensor. (B, H, T_2, d_k)
            v: Transformed value tensor. (B, H, T_2, d_k)
        """
        n_batch = query.size(0)
        q = (
            self.linear_q(query)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        k = (
            self.linear_k(key)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        v = (
            self.linear_v(value)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        return q, k, v
    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention context vector.
        Args:
            value: Transformed value. (B, H, T_2, d_k)
            scores: Attention score. (B, H, T_1, T_2)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)
        Returns:
           attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
        """
        batch_size = scores.size(0)
        mask = mask.unsqueeze(1).unsqueeze(2)
        if chunk_mask is not None:
            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
        scores = scores.masked_fill(mask, float("-inf"))
        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        attn_output = self.dropout(self.attn)
        attn_output = torch.matmul(attn_output, value)
        attn_output = self.linear_out(
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_k)
        )
        return attn_output
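    # Editorial note on the masking above: padded (True) positions are pushed to -inf
    # before the softmax and then zeroed afterwards, so rows whose positions are all
    # masked do not propagate NaNs into the weighted sum.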
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Compute scaled dot product attention with rel. positional encoding.
        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)
            left_context: Number of frames in left context.
        Returns:
            : Output tensor. (B, T_1, H * d_k)
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
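# Hypothetical usage sketch for the class above (shapes only; names and sizes are made
# up for illustration):
#     mha = RelPositionMultiHeadedAttentionChunk(num_heads=4, embed_size=256)
#     x = torch.randn(2, 32, 256)                  # (B, T, size)
#     pos = torch.randn(2, 2 * 32 - 1, 256)        # (B, 2 * T - 1, size)
#     mask = torch.zeros(2, 32, dtype=torch.bool)  # True marks padded frames
#     out = mha(x, x, x, pos, mask)                # (B, T, size)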
class CosineDistanceAttention(nn.Module):
    """ Compute Cosine Distance between spk decoder output and speaker profile
    Args:
        profile_path: speaker profile file path (.npy file)
    """
    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, spk_decoder_out, profile, profile_lens=None):
        """
        Args:
            spk_decoder_out(torch.Tensor):(B, L, D)
            spk_profiles(torch.Tensor):(B, N, D)
        """
        x = spk_decoder_out.unsqueeze(2)  # (B, L, 1, D)
        if profile_lens is not None:
            mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
            )
            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
            weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0)  # (B, L, N)
        else:
            x = x[:, -1:, :, :]
            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
            weights = self.softmax(weights_not_softmax)  # (B, 1, N)
        spk_embedding = torch.matmul(weights, profile.to(weights.device))  # (B, L, D)
        return spk_embedding, weights
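# Hypothetical usage sketch for CosineDistanceAttention (shapes only; values made up):
#     att = CosineDistanceAttention()
#     dec_out = torch.randn(2, 10, 512)       # (B, L, D) spk decoder output
#     profile = torch.randn(2, 4, 512)        # (B, N, D) speaker profiles
#     profile_lens = torch.tensor([4, 2])     # valid profiles per utterance
#     spk_emb, weights = att(dec_out, profile, profile_lens)
#     # spk_emb: (B, L, D) profile-weighted embedding; weights: (B, L, N)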