Legend
2024-12-15 5e7a8d1ccae80e54f2e2ecfffdf8e4294800b5c3
funasr/models/paraformer/cif_predictor.py
@@ -80,7 +80,7 @@
                    hidden, alphas, token_num, mask=mask
                )
            acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
            acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
            if target_length is None and self.tail_threshold > 0.0:
                token_num_int = torch.max(token_num).type(torch.int32).item()
@@ -506,7 +506,10 @@
    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
    prefix_sum = torch.cumsum(alphas, dim=1)
    # prefix_sum = torch.cumsum(alphas, dim=1)
    prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
        torch.float32
    )  # cumsum precision degradation cause wrong result in extreme
    prefix_sum_floor = torch.floor(prefix_sum)
    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
@@ -518,8 +521,8 @@
    fires[fire_idxs] = 1
    fires = fires + prefix_sum - prefix_sum_floor
    prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
    # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
    prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1)
    frames = prefix_sum_hidden[fire_idxs]
    shift_frames = torch.roll(frames, 1, dims=0)
@@ -530,15 +533,19 @@
    shift_frames[shift_batch_idxs] = 0
    remains = fires - torch.floor(fires)
    remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
    # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
    remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs]
    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
    shift_remain_frames[shift_batch_idxs] = 0
    frames = frames - shift_frames + shift_remain_frames - remain_frames
    max_label_len = batch_len.max()
    # max_label_len = batch_len.max()
    max_label_len = alphas.sum(dim=-1)
    max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64)
    # frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
    frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
    frame_fires_idxs = indices < batch_len.unsqueeze(1)
@@ -667,7 +674,10 @@
    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
    prefix_sum = torch.cumsum(alphas, dim=1)
    # prefix_sum = torch.cumsum(alphas, dim=1)
    prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
        torch.float32
    )  # cumsum precision degradation cause wrong result in extreme
    prefix_sum_floor = torch.floor(prefix_sum)
    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
@@ -689,8 +699,10 @@
    device = hidden.device
    dtype = hidden.dtype
    batch_size, len_time, hidden_size = hidden.size()
    # frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
    # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
    prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
    prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1)
    frames = prefix_sum_hidden[fire_idxs]
    shift_frames = torch.roll(frames, 1, dims=0)
@@ -702,15 +714,20 @@
    shift_frames[shift_batch_idxs] = 0
    remains = fires - torch.floor(fires)
    remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
    # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
    remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs]
    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
    shift_remain_frames[shift_batch_idxs] = 0
    frames = frames - shift_frames + shift_remain_frames - remain_frames
    max_label_len = batch_len.max()
    # max_label_len = batch_len.max()
    max_label_len = (
        torch.round(alphas.sum(-1)).int().max()
    )  # torch.round to calculate the max length
    # frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
    frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
    frame_fires_idxs = indices < batch_len.unsqueeze(1)