夜雨飘零
2023-12-19 53fccccb24d15d788919d91c8c2b06a115ddacf3
funasr/export/models/encoder/sanm_encoder.py
@@ -3,11 +3,12 @@
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.modules.attention import MultiHeadedAttentionSANM
from funasr.models.transformer.attention import MultiHeadedAttentionSANM
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
from funasr.models.transformer.embedding import StreamSinusoidalPositionEncoder
class SANMEncoder(nn.Module):
@@ -21,6 +22,8 @@
    ):
        super().__init__()
        self.embed = model.embed
        if isinstance(self.embed, StreamSinusoidalPositionEncoder):
            self.embed = None
        self.model = model
        self.feats_dim = feats_dim
        self._output_size = model._output_size
@@ -63,8 +66,10 @@
    def forward(self,
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                online: bool = False
                ):
        speech = speech * self._output_size ** 0.5
        if not online:
            speech = speech * self._output_size ** 0.5
        mask = self.make_pad_mask(speech_lengths)
        mask = self.prepare_mask(mask)
        if self.embed is None: