import torch.nn as nn


# from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask

from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask

from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
# Export-friendly feed-forward counterpart; the module path below is assumed
# to mirror the multihead_att import above.
from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export

class ParaformerSANMDecoder(nn.Module):
    """Export wrapper for a Paraformer SANM decoder (e.g. for ONNX tracing)."""

    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True):
        super().__init__()
        # self.embed = model.embed  # Embedding(model.embed, max_seq_len)
        self.model = model
        # MakePadMask is the ONNX-traceable padding-mask op; sequence_mask is
        # the plain PyTorch fallback used outside of export.
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

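        # Replace training-time submodules in place with export-friendly
        # versions, so the traced graph contains only exportable ops.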
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def prepare_mask(self, mask):
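        # (B, T) padding mask -> (B, T, 1), broadcastable over the feature dim.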
        mask_3d_btd = mask[:, :, None]