| | |
| | | |
| | | |
class SequenceBinaryCrossEntropy(nn.Module):
    """Holder for an elementwise binary cross-entropy criterion over logits.

    Fix: the original text contained the ``__init__`` signature twice (a
    multi-line and a single-line copy fused by a bad merge), which is a
    syntax error; exactly one signature is kept.
    """

    def __init__(
        self,
        normalize_length=False,
        criterion=nn.BCEWithLogitsLoss(reduction="none"),
    ):
        """Construct a SequenceBinaryCrossEntropy object.

        Args:
            normalize_length (bool): normalization flag stored on the module.
                NOTE(review): it is only stored here; how it is consumed is
                not visible in this chunk — confirm against the full file.
            criterion (nn.Module): elementwise loss with ``reduction="none"``
                (so a caller can mask before reducing). The default instance
                is created once at class-definition time and shared by every
                instance using the default; ``BCEWithLogitsLoss`` holds no
                per-call state, so the sharing is harmless here.
        """
        super().__init__()
        self.normalize_length = normalize_length
        self.criterion = criterion
| | |
class NllLoss(nn.Module):
    """Negative log-likelihood loss that contributes zero at padding positions.

    NOTE(review): reconstructed from a merge-garbled fragment. Visible were
    the constructor's parameter tail (with the ``criterion=`` kwarg
    duplicated — a syntax error), its docstring and ``super()`` call, and the
    masking/normalization body of ``forward``. The class header, the
    attribute assignments, the ``forward`` signature, and the ``batch_size``
    definition were not visible and are reconstructed here — confirm against
    the original file.
    """

    def __init__(
        self,
        size,
        padding_idx,
        normalize_length=False,
        criterion=nn.NLLLoss(reduction="none"),
    ):
        """Construct an NllLoss object.

        Args:
            size (int): number of classes (last dimension of the input).
            padding_idx (int): target value that marks padding; those
                positions are excluded from the loss.
            normalize_length (bool): if True divide by the number of
                non-padding tokens, otherwise by the batch size.
            criterion (nn.Module): elementwise loss; must use
                ``reduction="none"`` so padding can be masked afterwards.
                The default instance is shared across instances relying on
                the default; ``NLLLoss`` is stateless, so this is safe.
        """
        super(NllLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.size = size
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute the padding-masked NLL loss.

        Args:
            x (torch.Tensor): log-probabilities — assumed shape
                (batch, seq, size) since ``nn.NLLLoss`` expects log-probs;
                TODO(review) confirm shape against callers.
            target (torch.Tensor): class indices, assumed (batch, seq),
                with ``padding_idx`` at padded positions.

        Returns:
            torch.Tensor: scalar — masked per-token losses summed, divided
            by token count or batch size depending on ``normalize_length``.
        """
        batch_size = x.size(0)
        # Flatten so the elementwise criterion sees (N, size) vs (N,).
        x = x.view(-1, self.size)
        target = target.view(-1)
        ignore = target == self.padding_idx  # (B,)
        total = len(target) - ignore.sum().item()
        target = target.masked_fill(ignore, 0)  # avoid -1 index
        kl = self.criterion(x, target)
        denom = total if self.normalize_length else batch_size
        # Zero out padded positions before reducing.
        return kl.masked_fill(ignore, 0).sum() / denom