import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import LayerNorm  # stand-in; the repo may ship its own LayerNorm
class Encoder_Conformer_Layer(nn.Module):
    """Encoder layer module."""

    def forward(self, x_input, mask, cache=None):
        x, pos_emb = x_input if isinstance(x_input, tuple) else (x_input, None)
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            # Streaming cache: attend only from the newest frame.
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if self.cca_pos < 2:
            if pos_emb is not None:
                x_att = self.self_attn(x_q, x, x, pos_emb, mask)
            else:
                x_att = self.self_attn(x_q, x, x, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask


class EncoderLayer(nn.Module):
    """Encoder layer combining cross-channel attention with a conformer block."""

    def __init__(self, size, self_attn_cros_channel, self_attn_conformer,
                 feed_forward_csa, feed_forward_macaron_csa, conv_module_csa,
                 dropout_rate, normalize_before, concat_after):
        # Signature inferred from the attribute assignments below.
        super().__init__()
        self.encoder_cros_channel_atten = self_attn_cros_channel
        self.encoder_csa = Encoder_Conformer_Layer(
            size,
            self_attn_conformer,
            feed_forward_csa,
            feed_forward_macaron_csa,
            conv_module_csa,
            dropout_rate,
            normalize_before,
            concat_after,
            cca_pos=0,
        )
        self.norm_mha = LayerNorm(size)  # LayerNorm for the cross-channel MHA module
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x_input, mask, channel_size, cache=None):
        """Compute encoded features."""
        x, pos_emb = x_input if isinstance(x_input, tuple) else (x_input, None)
        residual = x
        x = self.norm_mha(x)
        t_leng = x.size(1)  # time length T
        d_dim = x.size(2)  # feature dimension D
        # Regroup frames by channel: (B*C, T, D) -> (B, T, C, D).
        x_new = x.reshape(-1, channel_size, t_leng, d_dim).transpose(1, 2)
        # Gather a +-2 frame temporal context for every time step as keys/values.
        x_k_v = x_new.new(x_new.size(0), x_new.size(1), 5, x_new.size(2), x_new.size(3))
        pad_before = Variable(torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))).type(
            x_new.type()
        )
        pad_after = Variable(torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))).type(
            x_new.type()
        )
        x_pad = torch.cat([pad_before, x_new, pad_after], 1)  # (B, T + 4, C, D)
        x_k_v[:, :, 0, :, :] = x_pad[:, 0:-4, :, :]  # frame t - 2
        x_k_v[:, :, 1, :, :] = x_pad[:, 1:-3, :, :]  # frame t - 1
        x_k_v[:, :, 2, :, :] = x_pad[:, 2:-2, :, :]  # frame t
        x_k_v[:, :, 3, :, :] = x_pad[:, 3:-1, :, :]  # frame t + 1
        x_k_v[:, :, 4, :, :] = x_pad[:, 4:, :, :]  # frame t + 2
        # Flatten time into the batch: queries (B*T, C, D), keys/values (B*T, 5*C, D).
        x_new = x_new.reshape(-1, channel_size, d_dim)
        x_k_v = x_k_v.reshape(-1, 5 * channel_size, d_dim)
        x_att = self.encoder_cros_channel_atten(x_new, x_k_v, x_k_v, None)
        # Back to (B*C, T, D), matching the input layout.
        x_att = (
            x_att.reshape(-1, t_leng, channel_size, d_dim)
            .transpose(1, 2)
            .reshape(-1, t_leng, d_dim)
        )
        x = residual + self.dropout(x_att)
        if pos_emb is not None:
            x_input = (x, pos_emb)
        else:
            x_input = x
        # Run the conformer self-attention block after cross-channel attention.
        x_input, mask = self.encoder_csa(x_input, mask)

        return x_input, mask, channel_size
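

# A minimal, self-contained sketch of the +-2 frame key/value window built in
# EncoderLayer.forward above. The shapes B, T, C, D and the use of F.pad/unfold
# are illustrative assumptions, not part of the module; the sketch only checks
# that the five shifted slice assignments equal one unfold over the zero-padded
# time axis.
if __name__ == "__main__":
    import torch.nn.functional as F

    B, T, C, D = 2, 7, 4, 8  # hypothetical batch / time / channel / feature sizes
    demo = torch.randn(B, T, C, D)

    # Slice-assignment construction, as in forward().
    demo_pad = F.pad(demo, (0, 0, 0, 0, 2, 2))  # zero-pad two frames on each side of time
    window = demo.new_zeros(B, T, 5, C, D)
    for k in range(5):  # k = 0..4 covers frames t-2 .. t+2
        window[:, :, k, :, :] = demo_pad[:, k : k + T, :, :]

    # Equivalent one-liner: a length-5 sliding window over the padded time axis.
    window_alt = demo_pad.unfold(1, 5, 1).permute(0, 1, 4, 2, 3)
    assert torch.equal(window, window_alt)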