| | |
| | | concat_after: bool = False, |
| | | att_layer_num: int = 6, |
| | | kernel_size: int = 21, |
| | | sanm_shfit: int = 0, |
| | | sanm_shift: int = 0, |
| | | ): |
| | | super().__init__( |
| | | vocab_size=vocab_size, |
| | |
| | | |
| | | self.att_layer_num = att_layer_num |
| | | self.num_blocks = num_blocks |
| | | if sanm_shfit is None: |
| | | sanm_shfit = (kernel_size - 1) // 2 |
| | | if sanm_shift is None: |
| | | sanm_shift = (kernel_size - 1) // 2 |
| | | self.decoders = repeat( |
| | | att_layer_num - 1, |
| | | lambda lnum: DecoderLayerSANM( |
| | | attention_dim, |
| | | MultiHeadedAttentionSANMDecoder( |
| | | attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit |
| | | attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=sanm_shift |
| | | ), |
| | | MultiHeadedAttentionCrossAtt( |
| | | attention_heads, attention_dim, src_attention_dropout_rate |
| | |
| | | self.last_decoder = ContextualDecoderLayer( |
| | | attention_dim, |
| | | MultiHeadedAttentionSANMDecoder( |
| | | attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit |
| | | attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=sanm_shift |
| | | ), |
| | | MultiHeadedAttentionCrossAtt( |
| | | attention_heads, attention_dim, src_attention_dropout_rate |
| | |
| | | lambda lnum: DecoderLayerSANM( |
| | | attention_dim, |
| | | MultiHeadedAttentionSANMDecoder( |
| | | attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0 |
| | | attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=0 |
| | | ), |
| | | None, |
| | | PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate), |