| | |
| | | from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export |
| | | from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward |
| | | from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export |
| | | from funasr.modules.embedding import StreamSinusoidalPositionEncoder |
| | | |
| | | |
| | | class SANMEncoder(nn.Module): |
| | |
| | | ): |
| | | super().__init__() |
| | | self.embed = model.embed |
| | | if isinstance(self.embed, StreamSinusoidalPositionEncoder): |
| | | self.embed = None |
| | | self.model = model |
| | | self.feats_dim = feats_dim |
| | | self._output_size = model._output_size |
| | |
| | | def forward(self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | online: bool = False |
| | | ): |
| | | speech = speech * self._output_size ** 0.5 |
| | | if not online: |
| | | speech = speech * self._output_size ** 0.5 |
| | | mask = self.make_pad_mask(speech_lengths) |
| | | mask = self.prepare_mask(mask) |
| | | if self.embed is None: |