zhifu gao
2024-12-25 3f8294b9d7deaa0cbdb0b2ef6f3802d46ae133a9
funasr/models/sanm/decoder.py
@@ -226,7 +226,7 @@
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shift: int = None,
        sanm_shfit: int = None,
        concat_embeds: bool = False,
        attention_dim: int = None,
        tf2torch_tensor_name_prefix_torch: str = "decoder",
@@ -271,14 +271,14 @@
        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        if sanm_shift is None:
            sanm_shift = (kernel_size - 1) // 2
        if sanm_shfit is None:
            sanm_shfit = (kernel_size - 1) // 2
        self.decoders = repeat(
            att_layer_num,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=sanm_shift
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads,
@@ -303,7 +303,7 @@
                        attention_dim,
                        self_attention_dropout_rate,
                        kernel_size,
                        sanm_shift=sanm_shift,
                        sanm_shfit=sanm_shfit,
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),