From 66880c2a1aeb3f94ce0020a71397e213beb9f3a0 Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: 星期二, 01 八月 2023 21:00:50 +0800
Subject: [PATCH] TOLD/SOND: update SequenceBinaryCrossEntropy loss
---
funasr/losses/label_smoothing_loss.py | 4 ++--
1 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/funasr/losses/label_smoothing_loss.py b/funasr/losses/label_smoothing_loss.py
index 3ea34c0..c272ea8 100644
--- a/funasr/losses/label_smoothing_loss.py
+++ b/funasr/losses/label_smoothing_loss.py
@@ -75,10 +75,10 @@
self.criterion = criterion
def forward(self, pred, label, lengths):
- pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1])
+ pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
loss = self.criterion(pred, label)
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
- return loss.masked_fill(pad_mask, 0).sum() / denom
+ return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
class NllLoss(nn.Module):
--
Gitblit v1.9.1