From f57b68121a526baea43b2e93f4540d8a2995f633 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 29 四月 2024 15:15:24 +0800
Subject: [PATCH] batch

---
 funasr/models/eend/utils/losses.py |   19 ++++++++++++++-----
 1 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/funasr/models/eend/utils/losses.py b/funasr/models/eend/utils/losses.py
index 756952d..82fb573 100644
--- a/funasr/models/eend/utils/losses.py
+++ b/funasr/models/eend/utils/losses.py
@@ -7,7 +7,11 @@
 def standard_loss(ys, ts):
     losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)]
     loss = torch.sum(torch.stack(losses))
-    n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device)
+    n_frames = (
+        torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts])))
+        .to(torch.float32)
+        .to(ys[0].device)
+    )
     loss = loss / n_frames
     return loss
 
@@ -31,10 +35,15 @@
 
 
 def cal_power_loss(logits, power_ts):
-    losses = [F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit) for logit, power_t in
-              zip(logits, power_ts)]
+    losses = [
+        F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit)
+        for logit, power_t in zip(logits, power_ts)
+    ]
     loss = torch.sum(torch.stack(losses))
-    n_frames = torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts]))).to(torch.float32).to(
-        power_ts[0].device)
+    n_frames = (
+        torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts])))
+        .to(torch.float32)
+        .to(power_ts[0].device)
+    )
     loss = loss / n_frames
     return loss

--
Gitblit v1.9.1