liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/eend/utils/losses.py
@@ -7,7 +7,11 @@
def standard_loss(ys, ts):
    """Return the length-weighted mean binary cross-entropy over a batch.

    Args:
        ys: Sequence of logit tensors, one per sequence. ``torch.sigmoid`` is
            applied to each before the loss is computed.
        ts: Sequence of target tensors paired element-wise with ``ys``. Each
            must match the shape of its logit tensor and hold values in
            ``[0, 1]``, as ``F.binary_cross_entropy`` requires.

    Returns:
        A 0-dim tensor: the sum of the per-sequence losses (each scaled by the
        size of its first dimension) divided by the total first-dimension size
        of all targets.
    """
    # F.binary_cross_entropy returns a mean; multiplying by len(y) makes each
    # sequence contribute in proportion to its length before the final
    # normalisation.
    losses = [F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts)]
    loss = torch.sum(torch.stack(losses))
    # Build the normaliser once, directly as a float32 scalar on the logits'
    # device (the original computed it twice via a NumPy round-trip).
    n_frames = torch.tensor(
        sum(t.shape[0] for t in ts),
        dtype=torch.float32,
        device=ys[0].device,
    )
    return loss / n_frames
@@ -31,10 +35,15 @@
def cal_power_loss(logits, power_ts):
    """Return the length-weighted mean cross-entropy over a batch.

    Args:
        logits: Sequence of logit tensors, one per sequence, passed as the
            ``input`` of ``F.cross_entropy``.
        power_ts: Sequence of target tensors paired element-wise with
            ``logits``. Each is cast to ``torch.long`` and used as class
            indices.

    Returns:
        A 0-dim tensor: the sum of the per-sequence losses (each scaled by the
        size of its logits' first dimension) divided by the total
        first-dimension size of all targets.
    """
    # F.cross_entropy returns a mean; multiplying by len(logit) makes each
    # sequence contribute in proportion to its length. Computed once here
    # (the original evaluated this comprehension twice and discarded the
    # first result).
    losses = [
        F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit)
        for logit, power_t in zip(logits, power_ts)
    ]
    loss = torch.sum(torch.stack(losses))
    # Build the normaliser once, directly as a float32 scalar on the targets'
    # device.
    n_frames = torch.tensor(
        sum(power_t.shape[0] for power_t in power_ts),
        dtype=torch.float32,
        device=power_ts[0].device,
    )
    return loss / n_frames