游雁
2024-02-19 94de39dde2e616a01683c518023d0fab72b4e103
funasr/losses/label_smoothing_loss.py
@@ -8,7 +8,7 @@
import torch
from torch import nn
from funasr.modules.nets_utils import make_pad_mask
from funasr.models.transformer.utils.nets_utils import make_pad_mask
class LabelSmoothingLoss(nn.Module):
@@ -75,10 +75,10 @@
        self.criterion = criterion
    def forward(self, pred, label, lengths):
        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1])
        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
        loss = self.criterion(pred, label)
        denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
        return loss.masked_fill(pad_mask, 0).sum() / denom
        return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
class NllLoss(nn.Module):
@@ -97,7 +97,7 @@
        normalize_length=False,
        criterion=nn.NLLLoss(reduction='none'),
    ):
        """Construct an LabelSmoothingLoss object."""
        """Construct an NllLoss object."""
        super(NllLoss, self).__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx