| | |
stats = {}

# 1. Forward decoder (teacher forcing).
# ys_pad layout per sample: [sos, task, lid, text, eos]
decoder_out = self.model.decoder(
    x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
)

# 2. Compute attention (cross-entropy) loss.
# Build the label tensor once (the original computed this whole triple twice
# back-to-back; the first pass was dead work and has been removed):
#   - where target_mask == 1, keep the token id
#   - where target_mask == 0 (prompt prefix / padding), write -1 (ignore index)
#   - any remaining 0 id is also remapped to -1
#   target_mask over [sos, task, lid, text, eos] is [0, 0, 1, 1, 1],
#   so the labels become [-1, -1, lid, text, eos].
mask = torch.ones_like(ys_pad) * (-1)
ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
ys_pad_mask[ys_pad_mask == 0] = -1
# Align predictions with next-token labels:
#   decoder_out predicts the NEXT token, so drop its final step: [sos, task, lid, text]
#   labels are shifted left by one position:                     [-1, lid, text, eos]
loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
| | | |
| | | with torch.no_grad(): |