| | |
| | | from torch.autograd import Variable |
| | | |
| | | |
| | | |
| | | class Encoder_Conformer_Layer(nn.Module): |
# NOTE(review): This chunk is a garbled table/diff extraction (the "| | |"
# gutter columns are not Python). Most of this class body is missing from
# view, so the fragments below are documented in place only — no code change
# is safe here.
| | | """Encoder layer module. |
| | | |
| | |
# Fragment: pre-norm branch — applies LayerNorm before the multi-head
# attention module when normalize_before is set (matches the self.norm_mha
# attribute created in EncoderLayer.__init__ elsewhere in this file).
# The statements surrounding this fragment are missing.
| | | if self.normalize_before: |
| | | x = self.norm_mha(x) |
| | | |
| | | |
# Fragment: cache handling — with no cache the full sequence x is used as
# the attention query. The else-branch body (presumably slicing x for
# incremental decoding — TODO confirm against the original file) is missing
# from this extraction, leaving the `else:` dangling.
| | | if cache is None: |
| | | x_q = x |
| | | else: |
| | |
| | | return x, mask |
| | | |
| | | |
| | | |
| | | |
| | | |
| | | class EncoderLayer(nn.Module): |
# NOTE(review): Garbled table/diff extraction. The duplicated
# unformatted/black-formatted statement pairs below (e.g. the two
# pad_before/pad_after versions) indicate this is a formatting diff, not
# runnable source. Large parts of the class are missing; comments only.
| | | """Encoder layer module. |
| | | |
| | |
# Fragment: tail of a constructor/submodule argument list.
# "cca_pos=0)" and the "cca_pos=0,\n)" pair are the before/after of a
# formatting pass — one call site, shown twice by the diff.
| | | dropout_rate, |
| | | normalize_before, |
| | | concat_after, |
| | | cca_pos=0) |
| | | cca_pos=0, |
| | | ) |
| | | self.norm_mha = LayerNorm(size) # for the MHA module |
| | | self.dropout = nn.Dropout(dropout_rate) |
| | | |
| | | |
| | | def forward(self, x_input, mask, channel_size, cache=None): |
| | | """Compute encoded features. |
| | |
# The statements that unpack x_input into x / pos_emb and that define
# t_leng and residual are missing from this extraction; all four names are
# used below, so lines have been lost between the docstring and here.
| | | d_dim = x.size(2) |
# Reshape to 4-D then swap dims 1 and 2; the trailing comment suggests the
# intended layout is (B*T, C, D) after the later flatten — TODO confirm.
| | | x_new = x.reshape(-1,channel_size,t_leng,d_dim).transpose(1,2) # x_new B*T * C * D |
# x_k_v: one slot per offset of a 5-frame context window. Allocated via
# .new(...) — UNINITIALIZED memory; only offsets 0 and 1 are filled below,
# so the assignments for planes 2..4 (offsets 2:-2, 3:-1, 4:) must exist in
# the missing lines, otherwise planes 2-4 would contain garbage.
| | | x_k_v = x_new.new(x_new.size(0),x_new.size(1),5,x_new.size(2),x_new.size(3)) |
# Zero-padding of 2 frames on each side of dim 1 (the 5-frame window's
# half-width). NOTE(review): torch.autograd.Variable (imported at top of
# file) is deprecated since PyTorch 0.4 — plain torch.zeros(...).type_as(
# x_new) would suffice; left untouched here. The next four statements are
# the unformatted + black-formatted versions of the SAME two assignments.
| | | pad_before = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type()) |
| | | pad_after = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type()) |
| | | pad_before = Variable(torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))).type( |
| | | x_new.type() |
| | | ) |
| | | pad_after = Variable(torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))).type( |
| | | x_new.type() |
| | | ) |
| | | x_pad = torch.cat([pad_before,x_new, pad_after], 1) |
# Gather window offsets 0 and 1 from the padded tensor; offsets 2..4 are in
# the missing lines (see note on x_k_v above).
| | | x_k_v[:,:,0,:,:]=x_pad[:,0:-4,:,:] |
| | | x_k_v[:,:,1,:,:]=x_pad[:,1:-3,:,:] |
| | |
# Flatten query to (?, channel_size, d_dim) and keys/values to
# (?, 5*channel_size, d_dim) for cross-channel attention (mask=None).
# NOTE(review): attribute name "encoder_cros_channel_atten" is misspelled
# ("cros") — defined outside this view, so it cannot be renamed safely here.
| | | x_new = x_new.reshape(-1,channel_size,d_dim) |
| | | x_k_v = x_k_v.reshape(-1,5*channel_size,d_dim) |
| | | x_att = self.encoder_cros_channel_atten(x_new, x_k_v, x_k_v, None) |
# The next two statements are again the unformatted + formatted versions of
# ONE reshape-transpose-reshape that restores (?, t_leng, d_dim).
| | | x_att = x_att.reshape(-1,t_leng,channel_size,d_dim).transpose(1,2).reshape(-1,t_leng,d_dim) |
| | | x_att = ( |
| | | x_att.reshape(-1, t_leng, channel_size, d_dim) |
| | | .transpose(1, 2) |
| | | .reshape(-1, t_leng, d_dim) |
| | | ) |
# Residual connection around the cross-channel attention; `residual` is
# defined in the missing lines (presumably the pre-norm input — confirm).
| | | x = residual + self.dropout(x_att) |
# Re-pair with positional embeddings when the layer received an (x, pos_emb)
# tuple, then run the downstream encoder self-attention stack.
| | | if pos_emb is not None: |
| | | x_input = (x, pos_emb) |
| | | else: |
| | | x_input = x |
| | | x_input, mask = self.encoder_csa(x_input, mask) |
| | | |
| | | |
| | | return x_input, mask , channel_size |