| funasr/losses/label_smoothing_loss.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/losses/label_smoothing_loss.py
@@ -75,10 +75,10 @@
         self.criterion = criterion
 
     def forward(self, pred, label, lengths):
-        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1])
+        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
         loss = self.criterion(pred, label)
         denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
-        return loss.masked_fill(pad_mask, 0).sum() / denom
+        return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
 
 
 class NllLoss(nn.Module):