| | |
| | | |
| | | import torch |
| | | from torch import nn |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | |
| | | |
| | | class LabelSmoothingLoss(nn.Module): |
| | |
| | | self.criterion = criterion |
| | | |
def forward(self, pred, label, lengths):
    """Compute the padding-masked loss, summed and then normalized.

    Args:
        pred: Predictions passed to ``self.criterion``. ``pred.shape[1]`` is
            used as the mask length and ``pred.shape[0]`` as the default
            normalizer.
        label: Targets passed to ``self.criterion`` together with ``pred``.
        lengths: Per-sequence lengths handed to ``make_pad_mask``.

    Returns:
        The sum of the loss entries left after masking, divided by the
        number of unmasked positions when ``self.normalize_length`` is
        truthy, otherwise by ``pred.shape[0]``.
    """
    # Build the mask exactly once and move it to pred's device so that
    # masked_fill below does not mix tensors from different devices.
    # NOTE(review): assumes make_pad_mask returns a boolean mask that is
    # True at padded positions -- confirm against the helper.
    pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
    loss = self.criterion(pred, label)
    denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
    # The trailing singleton axis lets the mask broadcast across the last
    # dimension of the loss.
    # NOTE(review): assumes self.criterion is unreduced and its output has
    # one more trailing dimension than pad_mask -- confirm.
    return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
| | | |
| | | |
| | | class NllLoss(nn.Module): |
| | |
| | | normalize_length=False, |
| | | criterion=nn.NLLLoss(reduction='none'), |
| | | ): |
| | | """Construct an LabelSmoothingLoss object.""" |
| | | """Construct an NllLoss object.""" |
| | | super(NllLoss, self).__init__() |
| | | self.criterion = criterion |
| | | self.padding_idx = padding_idx |