| | |
| | | |
| | | import torch |
| | | from torch import nn |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder |
| | | from funasr.modules.attention import MultiHeadedAttention |
| | |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | attention_dim = encoder_output_size |
| | | |
| | |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__( |
| | | vocab_size=vocab_size, |
| | | encoder_output_size=encoder_output_size, |
| | |
| | | concat_after: bool = False, |
| | | embeds_id: int = -1, |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__( |
| | | vocab_size=vocab_size, |
| | | encoder_output_size=encoder_output_size, |
| | |
| | | conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), |
| | | conv_usebias: int = False, |
| | | ): |
| | | assert check_argument_types() |
| | | if len(conv_kernel_length) != num_blocks: |
| | | raise ValueError( |
| | | "conv_kernel_length must have equal number of values to num_blocks: " |
| | |
| | | conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), |
| | | conv_usebias: int = False, |
| | | ): |
| | | assert check_argument_types() |
| | | if len(conv_kernel_length) != num_blocks: |
| | | raise ValueError( |
| | | "conv_kernel_length must have equal number of values to num_blocks: " |
| | |
| | | conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), |
| | | conv_usebias: int = False, |
| | | ): |
| | | assert check_argument_types() |
| | | if len(conv_kernel_length) != num_blocks: |
| | | raise ValueError( |
| | | "conv_kernel_length must have equal number of values to num_blocks: " |
| | |
| | | conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), |
| | | conv_usebias: int = False, |
| | | ): |
| | | assert check_argument_types() |
| | | if len(conv_kernel_length) != num_blocks: |
| | | raise ValueError( |
| | | "conv_kernel_length must have equal number of values to num_blocks: " |
| | |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | attention_dim = encoder_output_size |
| | | |
| | |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__( |
| | | vocab_size=vocab_size, |
| | | encoder_output_size=encoder_output_size, |