| | |
| | | """ |
| | | tgt = ys_in_pad |
| | | tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None] |
| | | |
| | | memory = hs_pad |
| | | memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :] |
| | | if chunk_mask is not None: |
| | | memory_mask = memory_mask * chunk_mask |
| | | if tgt_mask.size(1) != memory_mask.size(1): |
| | | memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1) |
| | | |
| | | memory = hs_pad |
| | | memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :] |
| | | |
| | | x = tgt |
| | | x, tgt_mask, memory, memory_mask, _ = self.decoders( |