funasr/models/decoder/sanm_decoder.py
@@ -94,7 +94,7 @@ if self.self_attn: if self.normalize_before: tgt = self.norm2(tgt) x, _ = self.self_attn(tgt, tgt_mask) x, cache = self.self_attn(tgt, tgt_mask, cache=cache) x = residual + self.dropout(x) if self.src_attn is not None: @@ -103,7 +103,6 @@ x = self.norm3(x) x = residual + self.dropout(self.src_attn(x, memory, memory_mask)) return x, tgt_mask, memory, memory_mask, cache