From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid naming conflict with line 365
---
funasr/models/eend/utils/losses.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)
diff --git a/funasr/models/eend/utils/losses.py b/funasr/models/eend/utils/losses.py
index 756952d..82fb573 100644
--- a/funasr/models/eend/utils/losses.py
+++ b/funasr/models/eend/utils/losses.py
@@ -7,7 +7,11 @@
def standard_loss(ys, ts):
losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)]
loss = torch.sum(torch.stack(losses))
- n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device)
+ n_frames = (
+ torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts])))
+ .to(torch.float32)
+ .to(ys[0].device)
+ )
loss = loss / n_frames
return loss
@@ -31,10 +35,15 @@
def cal_power_loss(logits, power_ts):
- losses = [F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit) for logit, power_t in
- zip(logits, power_ts)]
+ losses = [
+ F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit)
+ for logit, power_t in zip(logits, power_ts)
+ ]
loss = torch.sum(torch.stack(losses))
- n_frames = torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts]))).to(torch.float32).to(
- power_ts[0].device)
+ n_frames = (
+ torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts])))
+ .to(torch.float32)
+ .to(power_ts[0].device)
+ )
loss = loss / n_frames
return loss
--
Gitblit v1.9.1