speech_asr
2023-03-08 0892f5ce5240fde47fdcfc6f4faea8bfad6dc0ce
funasr/modules/eend_ola/utils/losses.py
@@ -8,19 +8,9 @@
def standard_loss(ys, ts, label_delay=0):
    """Compute the frame-averaged binary cross-entropy over a batch.

    Each ``(y, t)`` pair is scored with a mean-reduced sigmoid
    cross-entropy, re-weighted by the pair's length (``len(y)``) so that
    longer sequences contribute proportionally; the weighted losses are
    summed and divided by the total number of frames in ``ts``.

    Args:
        ys: Sequence of logit tensors (pre-sigmoid), one per sequence.
            Presumably shaped (frames, speakers) -- confirm with callers.
        ts: Sequence of float target tensors, each matching the shape of
            the corresponding entry in ``ys``.
        label_delay: Unused here; kept for interface compatibility.

    Returns:
        Scalar tensor holding the total loss divided by the total frame
        count.
    """
    # Work directly on logits: sigmoid followed by binary_cross_entropy
    # saturates for large-magnitude logits (the probability rounds to
    # exactly 0 or 1 and the log term gets clamped), which distorts both
    # the loss value and its gradient.
    per_sequence = [
        F.binary_cross_entropy_with_logits(y, t) * len(y)
        for y, t in zip(ys, ts)
    ]
    total = torch.sum(torch.stack(per_sequence))
    # Total number of frames across the batch, as a float32 scalar that
    # lives on the same device as the predictions.
    n_frames = torch.tensor(
        sum(t.shape[0] for t in ts),
        dtype=torch.float32,
        device=ys[0].device,
    )
    return total / n_frames
def batch_pit_loss(ys, ts, label_delay=0):
    """Compute the frame-averaged ``pit_loss`` over a batch.

    ``pit_loss`` is evaluated independently for every ``(y, t)`` pair; the
    per-sequence losses are summed and divided by the total number of
    frames in ``ts``.

    Args:
        ys: Sequence of prediction tensors, one per sequence.
        ts: Sequence of target tensors aligned with ``ys``; the first
            dimension of each is counted as its number of frames.
        label_delay: Unused here; kept for interface compatibility.

    Returns:
        A ``(loss, labels)`` tuple: ``loss`` is the scalar frame-averaged
        loss and ``labels`` is a tuple with the second element returned by
        ``pit_loss`` for each sequence.
    """
    loss_w_labels = [pit_loss(y, t) for y, t in zip(ys, ts)]
    losses, labels = zip(*loss_w_labels)
    loss = torch.sum(torch.stack(losses))
    # t.shape[0] is a plain Python int, so the counts cannot go through
    # torch.stack (it only accepts tensors). Sum them in Python and build
    # a float32 scalar on the same device as the loss.
    n_frames = torch.tensor(
        sum(t.shape[0] for t in ts),
        dtype=torch.float32,
        device=loss.device,
    )
    loss = loss / n_frames
    return loss, labels
def batch_pit_n_speaker_loss(ys, ts, n_speakers_list):