From 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 18:03:06 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/main' into dev_aky

---
 funasr/modules/eend_ola/utils/losses.py |   12 +-----------
 1 files changed, 1 insertions(+), 11 deletions(-)

diff --git a/funasr/modules/eend_ola/utils/losses.py b/funasr/modules/eend_ola/utils/losses.py
index 97443bc..af0181d 100644
--- a/funasr/modules/eend_ola/utils/losses.py
+++ b/funasr/modules/eend_ola/utils/losses.py
@@ -8,19 +8,9 @@
 def standard_loss(ys, ts, label_delay=0):
     losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)]
     loss = torch.sum(torch.stack(losses))
-    n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device)  # 璁$畻鎬荤殑甯ф暟
+    n_frames = torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))).to(torch.float32).to(ys[0].device)
     loss = loss / n_frames
     return loss
-
-
-def batch_pit_loss(ys, ts, label_delay=0):
-    loss_w_labels = [pit_loss(y, t)
-                     for (y, t) in zip(ys, ts)]
-    losses, labels = zip(*loss_w_labels)
-    loss = torch.sum(torch.stack(losses))
-    n_frames = torch.sum(torch.stack([t.shape[0] for t in ts]))
-    loss = loss / n_frames
-    return loss, labels
 
 
 def batch_pit_n_speaker_loss(ys, ts, n_speakers_list):

--
Gitblit v1.9.1