import torch
import torch.nn as nn

from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.modules.attention import MultiHeadedAttentionSANM
from funasr.models.transformer.attention import MultiHeadedAttentionSANM
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
from funasr.models.transformer.embedding import StreamSinusoidalPositionEncoder


| | | class SANMEncoder(nn.Module): |
| | |
| | | ): |
| | | super().__init__() |
| | | self.embed = model.embed |
| | | if isinstance(self.embed, StreamSinusoidalPositionEncoder): |
| | | self.embed = None |
| | | self.model = model |
| | | self.feats_dim = feats_dim |
| | | self._output_size = model._output_size |
| | |
| | | def forward(self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | online: bool = False |
| | | ): |
| | | speech = speech * self._output_size ** 0.5 |
| | | if not online: |
| | | speech = speech * self._output_size ** 0.5 |
| | | mask = self.make_pad_mask(speech_lengths) |
| | | mask = self.prepare_mask(mask) |
| | | if self.embed is None: |