zhifu gao
2024-03-11 15c4709beb4b588db2135fc1133cd6955b5ef819
funasr/models/paraformer/decoder.py
@@ -628,14 +628,12 @@
                 ):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask
        
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
@@ -763,14 +761,12 @@
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        self.model = model
        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
@@ -1036,14 +1032,12 @@
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        self.model = model
        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        from funasr.models.transformer.decoder import DecoderLayerExport