From f57b68121a526baea43b2e93f4540d8a2995f633 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 29 Apr 2024 15:15:24 +0800
Subject: [PATCH] batch
---
funasr/models/eend/utils/losses.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)
diff --git a/funasr/models/eend/utils/losses.py b/funasr/models/eend/utils/losses.py
index 756952d..82fb573 100644
--- a/funasr/models/eend/utils/losses.py
+++ b/funasr/models/eend/utils/losses.py
@@ -7,7 +7,11 @@
 def standard_loss(ys, ts):
     losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)]
     loss = torch.sum(torch.stack(losses))
-    n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device)
+    n_frames = (
+        torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts])))
+        .to(torch.float32)
+        .to(ys[0].device)
+    )
     loss = loss / n_frames
     return loss
 
@@ -31,10 +35,15 @@
 
 def cal_power_loss(logits, power_ts):
-    losses = [F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit) for logit, power_t in
-              zip(logits, power_ts)]
+    losses = [
+        F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit)
+        for logit, power_t in zip(logits, power_ts)
+    ]
     loss = torch.sum(torch.stack(losses))
-    n_frames = torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts]))).to(torch.float32).to(
-        power_ts[0].device)
+    n_frames = (
+        torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts])))
+        .to(torch.float32)
+        .to(power_ts[0].device)
+    )
     loss = loss / n_frames
     return loss
 
--
Gitblit v1.9.1