hnluo
2023-04-03 85cf8c66967c2e1f8181ec4ff4d54b6cd26e21f9
funasr/models/decoder/sanm_decoder.py
@@ -94,7 +94,7 @@
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            x, _ = self.self_attn(tgt, tgt_mask)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + self.dropout(x)
        if self.src_attn is not None:
@@ -103,7 +103,6 @@
                x = self.norm3(x)
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
        return x, tgt_mask, memory, memory_mask, cache