夜雨飘零
2023-12-19 53fccccb24d15d788919d91c8c2b06a115ddacf3
funasr/export/models/encoder/conformer_encoder.py
@@ -3,14 +3,14 @@
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.modules.attention import MultiHeadedAttentionSANM
from funasr.models.transformer.attention import MultiHeadedAttentionSANM
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as EncoderLayerConformer_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
from funasr.export.models.encoder.sanm_encoder import SANMEncoder
from funasr.modules.attention import RelPositionMultiHeadedAttention
from funasr.models.transformer.attention import RelPositionMultiHeadedAttention
# from funasr.export.models.modules.multihead_att import RelPositionMultiHeadedAttention as RelPositionMultiHeadedAttention_export
from funasr.export.models.modules.multihead_att import OnnxRelPosMultiHeadedAttention as RelPositionMultiHeadedAttention_export
@@ -61,7 +61,6 @@
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                ):
        speech = speech * self._output_size ** 0.5
        mask = self.make_pad_mask(speech_lengths)
        mask = self.prepare_mask(mask)
        if self.embed is None: