| | |
| | | concat_after: bool = False, |
| | | att_layer_num: int = 6, |
| | | kernel_size: int = 21, |
| | | sanm_shfit: int = 0, |
| | | sanm_shift: int = 0, |
| | | ): |
| | | super().__init__( |
| | | vocab_size=vocab_size, |
| | |
| | | |
| | | self.att_layer_num = att_layer_num |
| | | self.num_blocks = num_blocks |
| | | if sanm_shfit is None: |
| | | sanm_shfit = (kernel_size - 1) // 2 |
| | | if sanm_shift is None: |
| | | sanm_shift = (kernel_size - 1) // 2 |
| | | self.decoders = repeat( |
| | | att_layer_num - 1, |
| | | lambda lnum: DecoderLayerSANM( |
| | | attention_dim, |
| | | MultiHeadedAttentionSANMDecoder( |
| | | attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit |
| | | attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=sanm_shift |
| | | ), |
| | | MultiHeadedAttentionCrossAtt( |
| | | attention_heads, attention_dim, src_attention_dropout_rate |
| | |
| | | self.last_decoder = ContextualDecoderLayer( |
| | | attention_dim, |
| | | MultiHeadedAttentionSANMDecoder( |
| | | attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit |
| | | attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=sanm_shift |
| | | ), |
| | | MultiHeadedAttentionCrossAtt( |
| | | attention_heads, attention_dim, src_attention_dropout_rate |
| | |
| | | lambda lnum: DecoderLayerSANM( |
| | | attention_dim, |
| | | MultiHeadedAttentionSANMDecoder( |
| | | attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0 |
| | | attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=0 |
| | | ), |
| | | None, |
| | | PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate), |
| | |
| | | # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :] |
| | | contextual_mask = self.make_pad_mask(contextual_length) |
| | | contextual_mask, _ = self.prepare_mask(contextual_mask) |
| | | # import pdb; pdb.set_trace() |
| | | contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1) |
| | | cx, tgt_mask, _, _, _ = self.bias_decoder( |
| | | x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask |