| | |
| | | ) |
| | | from funasr.register import tables |
| | | |
| | | |
| | | class EBranchformerEncoderLayer(torch.nn.Module): |
| | | """E-Branchformer encoder layer module. |
| | | |
| | |
| | | |
| | | return x, mask |
| | | |
| | | |
| | | @tables.register("encoder_classes", "EBranchformerEncoder") |
| | | class EBranchformerEncoder(nn.Module): |
| | | """E-Branchformer encoder module.""" |
| | |
| | | elif pos_enc_layer_type == "legacy_rel_pos": |
| | | assert attention_layer_type == "legacy_rel_selfattn" |
| | | pos_enc_class = LegacyRelPositionalEncoding |
| | | logging.warning( |
| | | "Using legacy_rel_pos and it will be deprecated in the future." |
| | | ) |
| | | logging.warning("Using legacy_rel_pos and it will be deprecated in the future.") |
| | | else: |
| | | raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) |
| | | |
| | |
| | | output_size, |
| | | attention_dropout_rate, |
| | | ) |
| | | logging.warning( |
| | | "Using legacy_rel_selfattn and it will be deprecated in the future." |
| | | ) |
| | | logging.warning("Using legacy_rel_selfattn and it will be deprecated in the future.") |
| | | elif attention_layer_type == "rel_selfattn": |
| | | assert pos_enc_layer_type == "rel_pos" |
| | | encoder_selfattn_layer = RelPositionMultiHeadedAttention |
| | |
| | | encoder_selfattn_layer(*encoder_selfattn_layer_args), |
| | | cgmlp_layer(*cgmlp_layer_args), |
| | | positionwise_layer(*positionwise_layer_args) if use_ffn else None, |
| | | positionwise_layer(*positionwise_layer_args) |
| | | if use_ffn and macaron_ffn |
| | | else None, |
| | | positionwise_layer(*positionwise_layer_args) if use_ffn and macaron_ffn else None, |
| | | dropout_rate, |
| | | merge_conv_kernel, |
| | | ), |