        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
-       sanm_shfit: int = 0,
+       sanm_shift: int = 0,
        lora_list: List[str] = None,
        lora_rank: int = 8,
        lora_alpha: int = 16,

        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
-       if sanm_shfit is None:
-           sanm_shfit = (kernel_size - 1) // 2
+       if sanm_shift is None:
+           sanm_shift = (kernel_size - 1) // 2
        self.decoders = repeat(
            att_layer_num,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
-                   attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+                   attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=sanm_shift
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads,

            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
-                   attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
+                   attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=0
                ),
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),

            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        return True

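    # ONNX input names: target tokens, encoder memory, the predicted acoustic
    # embeddings, and one cache tensor per decoder layer (both `decoders` and
    # `decoders2` keep a per-layer cache).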
    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
            + ['cache_%d' % i for i in range(cache_num)]

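    # ONNX output names: the decoder output `y` plus one updated cache tensor
    # per decoder layer, in the same order as the input caches.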
    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['y'] \
            + ['out_cache_%d' % i for i in range(cache_num)]

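    # Axes whose size may vary at runtime (batch and sequence length),
    # presumably forwarded to the `dynamic_axes` argument of torch.onnx.export.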
    def get_dynamic_axes(self):
        ret = {
            'tgt': {