From 55c09aeaa25b4bb88a50e09ba68fa6ff00a6d676 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期一, 15 一月 2024 20:10:39 +0800
Subject: [PATCH] update readme, fix seaco bug

---
 funasr/losses/label_smoothing_loss.py |    6 +++---
 1 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/funasr/losses/label_smoothing_loss.py b/funasr/losses/label_smoothing_loss.py
index 3ea34c0..8f0809a 100644
--- a/funasr/losses/label_smoothing_loss.py
+++ b/funasr/losses/label_smoothing_loss.py
@@ -8,7 +8,7 @@
 
 import torch
 from torch import nn
-from funasr.modules.nets_utils import make_pad_mask
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
 
 
 class LabelSmoothingLoss(nn.Module):
@@ -75,10 +75,10 @@
         self.criterion = criterion
 
     def forward(self, pred, label, lengths):
-        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1])
+        pad_mask = make_pad_mask(lengths, maxlen=pred.shape[1]).to(pred.device)
         loss = self.criterion(pred, label)
         denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
-        return loss.masked_fill(pad_mask, 0).sum() / denom
+        return loss.masked_fill(pad_mask.unsqueeze(-1), 0).sum() / denom
 
 
 class NllLoss(nn.Module):

--
Gitblit v1.9.1