zhifu gao
2024-03-11 9d48230c4f8f25bf88c5d6105f97370a36c9cf43
funasr/models/paraformer/decoder.py
@@ -635,8 +635,9 @@
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        
        from funasr.models.sanm.multihead_att import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.multihead_att import MultiHeadedAttentionCrossAttExport
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
        
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):