Zhanzhao (Deo) Liang
2024-12-25 8c7b7e5feb68fda1fc4ddd627bad0f915358149e
funasr/models/sanm/attention.py
@@ -104,13 +104,13 @@
                "inf"
            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@@ -154,7 +154,7 @@
        n_feat,
        dropout_rate,
        kernel_size,
        sanm_shfit=0,
        sanm_shift=0,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
@@ -191,7 +191,7 @@
        else:
            self.linear_out = nn.Linear(n_feat, n_feat)
            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        attn = None
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fsmn_block = nn.Conv1d(
@@ -199,17 +199,17 @@
        )
        # padding
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        if sanm_shift > 0:
            left_padding = left_padding + sanm_shift
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
    def forward_fsmn(self, inputs, mask, mask_shift_chunk=None):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            if mask_shift_chunk is not None:
                mask = mask * mask_shift_chunk
            inputs = inputs * mask
        x = inputs.transpose(1, 2)
@@ -275,13 +275,13 @@
                "inf"
            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@@ -289,7 +289,7 @@
        return self.linear_out(x)  # (batch, time1, d_model)
    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
    def forward(self, x, mask, mask_shift_chunk=None, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention.
        Args:
@@ -304,7 +304,7 @@
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shift_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
@@ -400,8 +400,8 @@
    def forward_attention(self, value, scores, mask):
        scores = scores + mask
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(attn, value)  # (batch, head, time1, d_k)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -459,8 +459,8 @@
    def forward_attention(self, value, scores, mask):
        scores = scores + mask
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(attn, value)  # (batch, head, time1, d_k)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
@@ -478,7 +478,7 @@
    """
    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shift=0):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
@@ -490,13 +490,13 @@
        # padding
        # padding
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        if sanm_shift > 0:
            left_padding = left_padding + sanm_shift
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size
    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
    def forward(self, inputs, mask, cache=None, mask_shift_chunk=None):
        """
        :param x: (#batch, time1, size).
        :param mask: Mask tensor (#batch, 1, time)
@@ -509,9 +509,9 @@
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
            if mask_shfit_chunk is not None:
                # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :]))
                mask = mask * mask_shfit_chunk
            if mask_shift_chunk is not None:
                # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shift_chunk.size(), mask_shift_chunk[0:100:50, :, :]))
                mask = mask * mask_shift_chunk
            # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
            # print("in fsmn, mask", mask.size())
            # print("in fsmn, inputs", inputs.size())
@@ -683,18 +683,18 @@
            # logging.info(
            #     "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        if ret_attn:
            return self.linear_out(x), self.attn  # (batch, time1, d_model)
            return self.linear_out(x), attn  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)
    def forward(self, x, memory, memory_mask, ret_attn=False):
@@ -780,16 +780,16 @@
        return q, k, v
    def forward_attention(self, value, scores, mask, ret_attn):
        scores = scores + mask
        scores = scores + mask.to(scores.device)
        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(attn, value)  # (batch, head, time1, d_k)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        if ret_attn:
            return self.linear_out(context_layer), self.attn
            return self.linear_out(context_layer), attn
        return self.linear_out(context_layer)  # (batch, time1, d_model)
@@ -868,13 +868,13 @@
                "inf"
            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)