zhifu gao
2023-02-27 8cc5bbf99a59694228aafcbe8712e09b9a4cb26b
funasr/losses/label_smoothing_loss.py
@@ -8,6 +8,7 @@
import torch
from torch import nn
from funasr.modules.nets_utils import make_pad_mask
class LabelSmoothingLoss(nn.Module):
@@ -61,3 +62,20 @@
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
class SequenceBinaryCrossEntropy(nn.Module):
    """Binary cross-entropy over padded sequences, ignoring padded positions.

    The element-wise loss from ``criterion`` is zeroed wherever the padding
    mask built from ``lengths`` is set, summed, and divided either by the
    number of non-padded positions or by the batch size.

    Args:
        normalize_length: If True, divide the summed loss by the number of
            non-padded positions; otherwise divide by ``pred.shape[0]``.
        criterion: Callable ``criterion(pred, label)`` returning an
            un-reduced loss. Defaults to a fresh
            ``nn.BCEWithLogitsLoss(reduction="none")`` per instance.
    """

    def __init__(
            self,
            normalize_length=False,
            criterion=None
    ):
        super().__init__()
        self.normalize_length = normalize_length
        # Build the default here rather than in the signature so that a single
        # module object is not created at import time and shared by every
        # instance of this class.
        if criterion is None:
            criterion = nn.BCEWithLogitsLoss(reduction="none")
        self.criterion = criterion

    def forward(self, pred, label, lengths):
        """Compute the masked, normalized loss.

        Args:
            pred: Predictions passed to ``criterion``; axis 0 is used as the
                batch size and axis 1 as the maximum sequence length.
                Presumably logits of shape (batch, time) or
                (batch, time, ...) -- confirm against callers.
            label: Targets passed to ``criterion`` alongside ``pred``.
            lengths: Per-sequence valid lengths handed to ``make_pad_mask``.

        Returns:
            Scalar tensor: sum of the loss over non-padded positions divided
            by the chosen denominator.
        """
        loss = self.criterion(pred, label)
        # Keep the mask on the same device as the loss; this is a no-op when
        # make_pad_mask already returns it there.
        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(loss.device)
        denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
        # Append trailing singleton dims so the mask lines up with the leading
        # axes of the loss when pred carries extra feature dimensions
        # (assumes the mask covers the leading axes of pred -- confirm).
        extra_dims = loss.dim() - pad_mask.dim()
        if extra_dims > 0:
            fill_mask = pad_mask.reshape(*pad_mask.shape, *([1] * extra_dims))
        else:
            fill_mask = pad_mask
        return loss.masked_fill(fill_mask, 0).sum() / denom