| | |
| | | qk = qk + mask[:n_ctx, :n_ctx] |
| | | else: |
| | | mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) |
| | | min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min) |
| | | min_value = -float( |
| | | "inf" |
| | | ) # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min) |
| | | qk = qk.masked_fill(mask, min_value) |
| | | |
| | | qk = qk.float() |
| | |
| | | """Score.""" |
| | | ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0) |
| | | logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state) |
| | | logp = torch.log_softmax(logp, dim=-1) |
| | | return logp.squeeze(0)[-1, :], state |
| | | |
| | | |