Merge pull request #324 from alibaba-damo-academy/dev_lhn2
fix decoder cache
| | |
| | | if self.self_attn: |
| | | if self.normalize_before: |
| | | tgt = self.norm2(tgt) |
| | | x, cache = self.self_attn(tgt, tgt_mask, cache=cache) |
| | | x, _ = self.self_attn(tgt, tgt_mask) |
| | | x = residual + self.dropout(x) |
| | | |
| | | if self.src_attn is not None: |
| | |
| | | for i in range(self.att_layer_num): |
| | | decoder = self.decoders[i] |
| | | c = cache[i] |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder( |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( |
| | | x, tgt_mask, memory, memory_mask, cache=c |
| | | ) |
| | | new_cache.append(c_ret) |
| | |
| | | j = i + self.att_layer_num |
| | | decoder = self.decoders2[i] |
| | | c = cache[j] |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder( |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( |
| | | x, tgt_mask, memory, memory_mask, cache=c |
| | | ) |
| | | new_cache.append(c_ret) |
| | | |
| | | for decoder in self.decoders3: |
| | | x, tgt_mask, memory, memory_mask, _ = decoder( |
| | | x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk( |
| | | x, tgt_mask, memory, None, cache=None |
| | | ) |
| | | |
| | |
| | | for i in range(self.att_layer_num): |
| | | decoder = self.decoders[i] |
| | | c = cache[i] |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder( |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( |
| | | x, tgt_mask, memory, None, cache=c |
| | | ) |
| | | new_cache.append(c_ret) |
| | |
| | | j = i + self.att_layer_num |
| | | decoder = self.decoders2[i] |
| | | c = cache[j] |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder( |
| | | x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk( |
| | | x, tgt_mask, memory, None, cache=c |
| | | ) |
| | | new_cache.append(c_ret) |
| | | |
| | | for decoder in self.decoders3: |
| | | |
| | | x, tgt_mask, memory, memory_mask, _ = decoder( |
| | | x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk( |
| | | x, tgt_mask, memory, None, cache=None |
| | | ) |
| | | |