Yabin Li
2023-08-08 b454a1054fadbff0ee963944ff42f66b98317582
funasr/export/models/encoder/sanm_encoder.py
@@ -8,6 +8,7 @@
from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
from funasr.modules.embedding import StreamSinusoidalPositionEncoder
class SANMEncoder(nn.Module):
@@ -21,6 +22,8 @@
    ):
        super().__init__()
        self.embed = model.embed
        if isinstance(self.embed, StreamSinusoidalPositionEncoder):
            self.embed = None
        self.model = model
        self.feats_dim = feats_dim
        self._output_size = model._output_size
@@ -63,8 +66,10 @@
    def forward(self,
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                online: bool = False
                ):
        speech = speech * self._output_size ** 0.5
        if not online:
            speech = speech * self._output_size ** 0.5
        mask = self.make_pad_mask(speech_lengths)
        mask = self.prepare_mask(mask)
        if self.embed is None: