| | |
def standard_loss(ys, ts, label_delay=0):
    """Compute the frame-weighted binary cross-entropy over a batch.

    For each (prediction, target) pair, the mean BCE is multiplied by
    ``len(y)``. The weighted losses are summed and divided by the total
    number of target frames (``sum(t.shape[0] for t in ts)``), so longer
    sequences contribute proportionally more.

    Args:
        ys: sequence of logit tensors, one per sequence. The sigmoid is
            applied inside ``binary_cross_entropy_with_logits``, which is
            numerically stable even for very large logits.
        ts: sequence of target tensors, matched element-wise with ``ys``
            and of the same shape. Presumably float tensors with values in
            [0, 1]; TODO confirm against callers.
        label_delay: unused; kept for interface compatibility.

    Returns:
        Scalar tensor holding the frame-averaged loss.
    """
    losses = [
        F.binary_cross_entropy_with_logits(y, t) * len(y)
        for y, t in zip(ys, ts)
    ]
    loss = torch.sum(torch.stack(losses))
    # Total number of frames across the batch (first dimension of each target).
    n_frames = sum(t.shape[0] for t in ts)
    return loss / n_frames
| | | |
| | | |
def batch_pit_loss(ys, ts, label_delay=0):
    """Compute the frame-averaged permutation-invariant (PIT) loss over a batch.

    ``pit_loss`` is applied to each (prediction, target) pair. The
    per-sequence losses are summed and divided by the total number of target
    frames (``sum(t.shape[0] for t in ts)``).

    Args:
        ys: sequence of prediction tensors, one per sequence.
        ts: sequence of target tensors, matched element-wise with ``ys``.
            The first dimension of each is counted as frames.
        label_delay: unused; kept for interface compatibility.

    Returns:
        A ``(loss, labels)`` tuple:
            loss: scalar tensor, the summed ``pit_loss`` values divided by
                the total frame count.
            labels: tuple of the second values returned by ``pit_loss`` for
                each sequence. Presumably the best-permutation labels;
                confirm against ``pit_loss``.
    """
    loss_w_labels = [pit_loss(y, t) for y, t in zip(ys, ts)]
    losses, labels = zip(*loss_w_labels)
    loss = torch.sum(torch.stack(losses))
    # t.shape[0] is a plain Python int, and torch.stack only accepts tensors,
    # so sum the frame counts directly.
    n_frames = sum(t.shape[0] for t in ts)
    return loss / n_frames, labels
| | | |
| | | |
| | | def batch_pit_n_speaker_loss(ys, ts, n_speakers_list): |