zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/mfcca/encoder_layer_mfcca.py
@@ -15,7 +15,6 @@
from torch.autograd import Variable
class Encoder_Conformer_Layer(nn.Module):
    """Encoder layer module.
@@ -111,7 +110,6 @@
        if self.normalize_before:
            x = self.norm_mha(x)
        if cache is None:
            x_q = x
        else:
@@ -165,8 +163,6 @@
        return x, mask
class EncoderLayer(nn.Module):
    """Encoder layer module.
@@ -217,10 +213,10 @@
                dropout_rate,
                normalize_before,
                concat_after,
-                cca_pos=0)
+            cca_pos=0,
+        )
        self.norm_mha = LayerNorm(size)  # for the MHA module
        self.dropout = nn.Dropout(dropout_rate)
    def forward(self, x_input, mask, channel_size, cache=None):
        """Compute encoded features.
@@ -247,8 +243,12 @@
        d_dim = x.size(2)
        x_new = x.reshape(-1,channel_size,t_leng,d_dim).transpose(1,2) # x_new B*T * C * D
        x_k_v = x_new.new(x_new.size(0),x_new.size(1),5,x_new.size(2),x_new.size(3))
-        pad_before = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
-        pad_after = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
+        pad_before = Variable(torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))).type(
+            x_new.type()
+        )
+        pad_after = Variable(torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))).type(
+            x_new.type()
+        )
        x_pad = torch.cat([pad_before,x_new, pad_after], 1)
        x_k_v[:,:,0,:,:]=x_pad[:,0:-4,:,:]
        x_k_v[:,:,1,:,:]=x_pad[:,1:-3,:,:]
@@ -258,13 +258,16 @@
        x_new = x_new.reshape(-1,channel_size,d_dim)
        x_k_v = x_k_v.reshape(-1,5*channel_size,d_dim)
        x_att = self.encoder_cros_channel_atten(x_new, x_k_v, x_k_v, None)
-        x_att = x_att.reshape(-1,t_leng,channel_size,d_dim).transpose(1,2).reshape(-1,t_leng,d_dim)
+        x_att = (
+            x_att.reshape(-1, t_leng, channel_size, d_dim)
+            .transpose(1, 2)
+            .reshape(-1, t_leng, d_dim)
+        )
        x = residual + self.dropout(x_att)
        if pos_emb is not None:
            x_input =  (x, pos_emb)
        else:
            x_input = x
        x_input, mask = self.encoder_csa(x_input, mask)
        return x_input, mask , channel_size
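
For context on the hunks above: the code builds a five-frame cross-channel key/value tensor. The per-channel features are reshaped to (B, T, C, D), zero-padded by two frames on each side of the time axis, and slices at offsets 0-4 (the assignments for offsets 2-4 continue past the hunk shown) stack frames t-2 .. t+2 for every position t before the cross-channel attention call. Below is a minimal standalone sketch of that construction, not the repository's code: the sizes B, C, T, D are illustrative assumptions, and torch.stack is used in place of the in-place slice assignments.

import torch

# Illustrative sizes only; the real shapes come from the encoder input.
B, C, T, D = 2, 4, 10, 8
x = torch.randn(B * C, T, D)                    # per-channel encoder features

x_new = x.reshape(B, C, T, D).transpose(1, 2)   # (B, T, C, D)
pad = x_new.new_zeros(B, 2, C, D)               # two zero frames at each sequence edge
x_pad = torch.cat([pad, x_new, pad], dim=1)     # (B, T + 4, C, D)

# Stack frames t-2 .. t+2 for every position t: (B, T, 5, C, D).
x_k_v = torch.stack([x_pad[:, i : i + T] for i in range(5)], dim=2)

q = x_new.reshape(B * T, C, D)                  # one query row per frame, C channels
kv = x_k_v.reshape(B * T, 5 * C, D)             # keys/values span 5 frames x C channels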