From 7498bd7388afdde8d5e6f8a4cb6aeb8be8ac60fa Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期五, 08 三月 2024 11:37:46 +0800
Subject: [PATCH] update code
---
funasr/losses/label_smoothing_loss.py | 8 ++++----
1 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/funasr/losses/label_smoothing_loss.py b/funasr/losses/label_smoothing_loss.py
index 8f63df9..8f0809a 100644
--- a/funasr/losses/label_smoothing_loss.py
+++ b/funasr/losses/label_smoothing_loss.py
@@ -8,7 +8,7 @@
import torch
from torch import nn
-from funasr.modules.nets_utils import make_pad_mask
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
class LabelSmoothingLoss(nn.Module):
@@ -75,10 +75,10 @@
self.criterion = criterion
def forward(self, pred, label, lengths):
- pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1])
+ pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
loss = self.criterion(pred, label)
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
- return loss.masked_fill(pad_mask, 0).sum() / denom
+ return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
class NllLoss(nn.Module):
@@ -97,7 +97,7 @@
normalize_length=False,
criterion=nn.NLLLoss(reduction='none'),
):
- """Construct an LabelSmoothingLoss object."""
+ """Construct an NllLoss object."""
super(NllLoss, self).__init__()
self.criterion = criterion
self.padding_idx = padding_idx
--
Gitblit v1.9.1