| | |
| | | x = self.norm1(x) |
| | | |
| | | if self.concat_after: |
| | | x_concat = torch.cat((x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1) |
| | | x_concat = torch.cat( |
| | | (x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1 |
| | | ) |
| | | if self.in_size == self.size: |
| | | x = residual + stoch_layer_coeff * self.concat_linear(x_concat) |
| | | else: |
| | |
| | | |
| | | self.encoders = repeat( |
| | | num_blocks, |
| | | lambda lnum: EncoderLayer( |
| | | lambda lnum: ( |
| | | EncoderLayer( |
| | | output_size, |
| | | output_size, |
| | | MultiHeadSelfAttention( |
| | |
| | | dropout_rate, |
| | | normalize_before, |
| | | concat_after, |
| | | ) if lnum > 0 else EncoderLayer( |
| | | ) |
| | | if lnum > 0 |
| | | else EncoderLayer( |
| | | input_size, |
| | | output_size, |
| | | MultiHeadSelfAttention( |
| | |
| | | dropout_rate, |
| | | normalize_before, |
| | | concat_after, |
| | | ) |
| | | ), |
| | | ) |
| | | if self.normalize_before: |
| | |
| | | if len(intermediate_outs) > 0: |
| | | return (xs_pad, intermediate_outs), olens, None |
| | | return xs_pad, olens, None |
| | | |