志浩
2023-08-01 66880c2a1aeb3f94ce0020a71397e213beb9f3a0
TOLD/SOND: update SequenceBinaryCrossEntropy loss
1个文件已修改
4 ■■■■ 已修改文件
funasr/losses/label_smoothing_loss.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/losses/label_smoothing_loss.py
@@ -75,10 +75,10 @@
        self.criterion = criterion
    def forward(self, pred, label, lengths):
        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1])
        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
        loss = self.criterion(pred, label)
        denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
        return loss.masked_fill(pad_mask, 0).sum() / denom
        return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
class NllLoss(nn.Module):