From d783b24ba7d8a03dabfa2139fcbf40c216e0ea3d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 三月 2023 19:34:52 +0800
Subject: [PATCH] Merge pull request #199 from alibaba-damo-academy/dev_xw
---
funasr/losses/label_smoothing_loss.py | 18 ++++++++++++++++++
1 files changed, 18 insertions(+), 0 deletions(-)
diff --git a/funasr/losses/label_smoothing_loss.py b/funasr/losses/label_smoothing_loss.py
index 0d8b303..28df73f 100644
--- a/funasr/losses/label_smoothing_loss.py
+++ b/funasr/losses/label_smoothing_loss.py
@@ -8,6 +8,7 @@
import torch
from torch import nn
+from funasr.modules.nets_utils import make_pad_mask
class LabelSmoothingLoss(nn.Module):
@@ -61,3 +62,20 @@
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
denom = total if self.normalize_length else batch_size
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
+
+
+class SequenceBinaryCrossEntropy(nn.Module):
+ def __init__(
+ self,
+ normalize_length=False,
+ criterion=nn.BCEWithLogitsLoss(reduction="none")
+ ):
+ super().__init__()
+ self.normalize_length = normalize_length
+ self.criterion = criterion
+
+ def forward(self, pred, label, lengths):
+ pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1])
+ loss = self.criterion(pred, label)
+ denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
+ return loss.masked_fill(pad_mask, 0).sum() / denom
--
Gitblit v1.9.1