zhifu gao
2024-12-25 3f8294b9d7deaa0cbdb0b2ef6f3802d46ae133a9
funasr/models/contextual_paraformer/decoder.py
@@ -137,7 +137,7 @@
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
-        sanm_shift: int = 0,
+        sanm_shfit: int = 0,
    ):
        super().__init__(
            vocab_size=vocab_size,
@@ -179,14 +179,14 @@
        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
-        if sanm_shift is None:
-            sanm_shift = (kernel_size - 1) // 2
+        if sanm_shfit is None:
+            sanm_shfit = (kernel_size - 1) // 2
        self.decoders = repeat(
            att_layer_num - 1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
-                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=sanm_shift
+                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate
@@ -210,7 +210,7 @@
        self.last_decoder = ContextualDecoderLayer(
            attention_dim,
            MultiHeadedAttentionSANMDecoder(
-                attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=sanm_shift
+                attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
            ),
            MultiHeadedAttentionCrossAtt(
                attention_heads, attention_dim, src_attention_dropout_rate
@@ -228,7 +228,7 @@
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
-                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=0
+                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
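Note: the rename from sanm_shift back to sanm_shfit appears intended to match the parameter name that MultiHeadedAttentionSANMDecoder itself declares (the historical sanm_shfit spelling), since a keyword argument must match the callee's signature exactly. A minimal sketch, using a hypothetical stand-in class rather than the FunASR source, of why a mismatched keyword fails:

    # Hypothetical stand-in for the attention module; the point is only that the
    # caller's keyword must use the same spelling the constructor declares.
    class MultiHeadedAttentionSANMDecoder:
        def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
            self.sanm_shfit = sanm_shfit

    MultiHeadedAttentionSANMDecoder(512, 0.1, 21, sanm_shfit=0)    # keyword matches the signature
    # MultiHeadedAttentionSANMDecoder(512, 0.1, 21, sanm_shift=0)  # TypeError: unexpected keyword argument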