import torch.nn as nn

from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.modules.attention import MultiHeadedAttentionSANM
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export

class SANMEncoder(nn.Module):
    """ONNX/TorchScript export wrapper around a FunASR SANM encoder."""

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='encoder',
        onnx: bool = True,
    ):
        super().__init__()
        self.embed = model.embed
        self.model = model
        self.feats_dim = feats_dim
        self.model_name = model_name

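        # MakePadMask builds the padding mask as a traceable module for ONNX
        # export; the plain sequence_mask helper covers the non-ONNX path
        # (e.g. TorchScript).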
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

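        # Swap the stock SANM attention modules for their export-friendly
        # counterparts so the attention ops trace correctly.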
        if hasattr(model, 'encoders0'):
            for i, d in enumerate(self.model.encoders0):
                if isinstance(d.self_attn, MultiHeadedAttentionSANM):