liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/mfcca/encoder_layer_mfcca.py
@@ -15,7 +15,6 @@
from torch.autograd import Variable
class Encoder_Conformer_Layer(nn.Module):
    """Encoder layer module.
@@ -111,7 +110,6 @@
        if self.normalize_before:
            x = self.norm_mha(x)
        if cache is None:
            x_q = x
        else:
@@ -120,12 +118,12 @@
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]
-        if self.cca_pos<2:
+        if self.cca_pos < 2:
            if pos_emb is not None:
                x_att = self.self_attn(x_q, x, x, pos_emb, mask)
            else:
                x_att = self.self_attn(x_q, x, x, mask)
-        else:
+        else:
            x_att = self.self_attn(x_q, x, x, mask)
        if self.concat_after:
@@ -163,8 +161,6 @@
            return (x, pos_emb), mask
        return x, mask
class EncoderLayer(nn.Module):
@@ -209,18 +205,18 @@
        self.encoder_cros_channel_atten = self_attn_cros_channel
        self.encoder_csa = Encoder_Conformer_Layer(
-                size,
-                self_attn_conformer,
-                feed_forward_csa,
-                feed_forward_macaron_csa,
-                conv_module_csa,
-                dropout_rate,
-                normalize_before,
-                concat_after,
-                cca_pos=0)
+            size,
+            self_attn_conformer,
+            feed_forward_csa,
+            feed_forward_macaron_csa,
+            conv_module_csa,
+            dropout_rate,
+            normalize_before,
+            concat_after,
+            cca_pos=0,
+        )
        self.norm_mha = LayerNorm(size)  # for the MHA module
        self.dropout = nn.Dropout(dropout_rate)
    def forward(self, x_input, mask, channel_size, cache=None):
        """Compute encoded features.
@@ -245,26 +241,33 @@
        x = self.norm_mha(x)
        t_leng = x.size(1)
        d_dim = x.size(2)
-        x_new = x.reshape(-1,channel_size,t_leng,d_dim).transpose(1,2) # x_new B*T * C * D
-        x_k_v = x_new.new(x_new.size(0),x_new.size(1),5,x_new.size(2),x_new.size(3))
-        pad_before = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
-        pad_after = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
-        x_pad = torch.cat([pad_before,x_new, pad_after], 1)
-        x_k_v[:,:,0,:,:]=x_pad[:,0:-4,:,:]
-        x_k_v[:,:,1,:,:]=x_pad[:,1:-3,:,:]
-        x_k_v[:,:,2,:,:]=x_pad[:,2:-2,:,:]
-        x_k_v[:,:,3,:,:]=x_pad[:,3:-1,:,:]
-        x_k_v[:,:,4,:,:]=x_pad[:,4:,:,:]
-        x_new = x_new.reshape(-1,channel_size,d_dim)
-        x_k_v = x_k_v.reshape(-1,5*channel_size,d_dim)
+        x_new = x.reshape(-1, channel_size, t_leng, d_dim).transpose(1, 2)  # x_new B*T * C * D
+        x_k_v = x_new.new(x_new.size(0), x_new.size(1), 5, x_new.size(2), x_new.size(3))
+        pad_before = Variable(torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))).type(
+            x_new.type()
+        )
+        pad_after = Variable(torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))).type(
+            x_new.type()
+        )
+        x_pad = torch.cat([pad_before, x_new, pad_after], 1)
+        x_k_v[:, :, 0, :, :] = x_pad[:, 0:-4, :, :]
+        x_k_v[:, :, 1, :, :] = x_pad[:, 1:-3, :, :]
+        x_k_v[:, :, 2, :, :] = x_pad[:, 2:-2, :, :]
+        x_k_v[:, :, 3, :, :] = x_pad[:, 3:-1, :, :]
+        x_k_v[:, :, 4, :, :] = x_pad[:, 4:, :, :]
+        x_new = x_new.reshape(-1, channel_size, d_dim)
+        x_k_v = x_k_v.reshape(-1, 5 * channel_size, d_dim)
        x_att = self.encoder_cros_channel_atten(x_new, x_k_v, x_k_v, None)
-        x_att = x_att.reshape(-1,t_leng,channel_size,d_dim).transpose(1,2).reshape(-1,t_leng,d_dim)
+        x_att = (
+            x_att.reshape(-1, t_leng, channel_size, d_dim)
+            .transpose(1, 2)
+            .reshape(-1, t_leng, d_dim)
+        )
        x = residual + self.dropout(x_att)
        if pos_emb is not None:
-            x_input =  (x, pos_emb)
+            x_input = (x, pos_emb)
        else:
            x_input = x
        x_input, mask = self.encoder_csa(x_input, mask)
-        return x_input, mask , channel_size
+        return x_input, mask, channel_size
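
For context on the last hunk: the reformatted lines build, for every frame, a cross-channel key/value tensor covering a two-frame window on each side of the current frame across all microphone channels, which is then fed to encoder_cros_channel_atten. Below is a minimal standalone sketch of that windowing step; the helper name build_cross_channel_kv and the example shapes are illustrative only and not part of this commit.

import torch

def build_cross_channel_kv(x, channel_size):
    # Hypothetical helper re-creating the windowing done inline in EncoderLayer.forward.
    # x: (batch * channel_size, T, D) per-channel frame features after norm_mha.
    t_leng, d_dim = x.size(1), x.size(2)
    x_new = x.reshape(-1, channel_size, t_leng, d_dim).transpose(1, 2)  # (B, T, C, D)
    # Zero-pad two frames on each side of the time axis; the Variable(...) wrapper
    # in the original is legacy, new_zeros gives the same result on modern PyTorch.
    pad = x_new.new_zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))
    x_pad = torch.cat([pad, x_new, pad], dim=1)  # (B, T + 4, C, D)
    # Stack the five shifted views so frame t sees frames t-2 .. t+2,
    # matching the x_k_v[:, :, 0..4, :, :] assignments in the hunk above.
    x_k_v = torch.stack([x_pad[:, k:k + t_leng] for k in range(5)], dim=2)  # (B, T, 5, C, D)
    q = x_new.reshape(-1, channel_size, d_dim)       # (B*T, C, D) queries
    kv = x_k_v.reshape(-1, 5 * channel_size, d_dim)  # (B*T, 5*C, D) keys/values
    return q, kv

# Example: 2 utterances, 8 channels, 50 frames, 256-dim features.
x = torch.randn(2 * 8, 50, 256)
q, kv = build_cross_channel_kv(x, channel_size=8)
print(q.shape, kv.shape)  # torch.Size([100, 8, 256]) torch.Size([100, 40, 256])

After the attention call, the chained reshape/transpose/reshape in the hunk maps the attention output back to (batch * channel_size, T, D) so the result can be added to the residual and passed on to encoder_csa.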