Lizerui9926
2023-04-03 cbbb3007436bf7cefa45af9a5aad60c696da9732
funasr/models/decoder/sanm_decoder.py
@@ -94,7 +94,7 @@
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x, _ = self.self_attn(tgt, tgt_mask)
            x = residual + self.dropout(x)
        if self.src_attn is not None:
@@ -399,7 +399,7 @@
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            c = cache[i]
            x, tgt_mask, memory, memory_mask, c_ret = decoder(
            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                x, tgt_mask, memory, memory_mask, cache=c
            )
            new_cache.append(c_ret)
@@ -409,13 +409,13 @@
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                c = cache[j]
                x, tgt_mask, memory, memory_mask, c_ret = decoder(
                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                    x, tgt_mask, memory, memory_mask, cache=c
                )
                new_cache.append(c_ret)
        for decoder in self.decoders3:
            x, tgt_mask, memory, memory_mask, _ = decoder(
            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                x, tgt_mask, memory, None, cache=None
            )
@@ -1076,7 +1076,7 @@
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            c = cache[i]
            x, tgt_mask, memory, memory_mask, c_ret = decoder(
            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(c_ret)
@@ -1086,14 +1086,14 @@
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                c = cache[j]
                x, tgt_mask, memory, memory_mask, c_ret = decoder(
                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                    x, tgt_mask, memory, None, cache=c
                )
                new_cache.append(c_ret)
        for decoder in self.decoders3:
            x, tgt_mask, memory, memory_mask, _ = decoder(
            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                x, tgt_mask, memory, None, cache=None
            )