import torch
from torch import nn
from funasr.modules.nets_utils import make_pad_mask


class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss: KL divergence between the log-softmax of the
    predictions and a smoothed one-hot target distribution."""

    def __init__(
        self,
        size,
        padding_idx,
        smoothing,
        normalize_length=False,
        criterion=nn.KLDivLoss(reduction="none"),
    ):
        super().__init__()
        self.criterion = criterion
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.normalize_length = normalize_length

    def forward(self, x, target):
        # x: (batch, seqlen, size) logits; target: (batch, seqlen) class ids,
        # with padded positions set to padding_idx
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            true_dist = x.clone()
            # spread the smoothing mass uniformly over the non-target classes
            true_dist.fill_(self.smoothing / (self.size - 1))
            ignore = target == self.padding_idx
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid scattering at padding ids
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        # normalize by the non-padding token count or by the batch size
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom


class SequenceBinaryCrossEntropy(nn.Module):
    """Per-position binary cross-entropy over padded sequences."""

    def __init__(
        self,
        normalize_length=False,
        criterion=nn.BCEWithLogitsLoss(reduction="none"),
    ):
        super().__init__()
        self.normalize_length = normalize_length
        self.criterion = criterion

    def forward(self, pred, label, lengths):
        # pred, label: (batch, seqlen); lengths: (batch,) valid sequence lengths
        # move the mask to pred's device so masked_fill also works on GPU inputs
        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
        loss = self.criterion(pred, label)
        # normalize by the non-padding position count or by the batch size
        denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
        return loss.masked_fill(pad_mask, 0).sum() / denom
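

# Minimal smoke test for both losses, runnable as a script. The shapes and
# values below are illustrative assumptions, not taken from FunASR's training
# code.
if __name__ == "__main__":
    vocab, pad_id = 10, -1
    ce = LabelSmoothingLoss(size=vocab, padding_idx=pad_id, smoothing=0.1)
    logits = torch.randn(2, 5, vocab)
    targets = torch.randint(0, vocab, (2, 5))
    targets[1, 3:] = pad_id  # mark the padded tail of the second sequence
    print("label smoothing loss:", ce(logits, targets).item())

    bce = SequenceBinaryCrossEntropy(normalize_length=True)
    pred = torch.randn(2, 5)
    label = torch.randint(0, 2, (2, 5)).float()  # BCEWithLogitsLoss needs float targets
    lengths = torch.tensor([5, 3])
    print("sequence BCE:", bce(pred, label, lengths).item())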