From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/eend/utils/losses.py |   19 ++++++++++++++-----
 1 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/funasr/models/eend/utils/losses.py b/funasr/models/eend/utils/losses.py
index 756952d..82fb573 100644
--- a/funasr/models/eend/utils/losses.py
+++ b/funasr/models/eend/utils/losses.py
@@ -7,7 +7,11 @@
 def standard_loss(ys, ts):
     losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)]
     loss = torch.sum(torch.stack(losses))
-    n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device)
+    n_frames = (
+        torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts])))
+        .to(torch.float32)
+        .to(ys[0].device)
+    )
     loss = loss / n_frames
     return loss
 
@@ -31,10 +35,15 @@
 
 
 def cal_power_loss(logits, power_ts):
-    losses = [F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit) for logit, power_t in
-              zip(logits, power_ts)]
+    losses = [
+        F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit)
+        for logit, power_t in zip(logits, power_ts)
+    ]
     loss = torch.sum(torch.stack(losses))
-    n_frames = torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts]))).to(torch.float32).to(
-        power_ts[0].device)
+    n_frames = (
+        torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts])))
+        .to(torch.float32)
+        .to(power_ts[0].device)
+    )
     loss = loss / n_frames
     return loss

--
Gitblit v1.9.1