zhuangzhong
2024-06-07 c367f1b819bb40e70e379f8f80b1d9fd67e47c79
add cif_wo_hidden_v1
1 file changed
24 ■■■■ lines changed in modified file
funasr/models/paraformer/cif_predictor.py 24 ●●●● patch | view | raw | blame | history
funasr/models/paraformer/cif_predictor.py
@@ -661,14 +661,13 @@
    return torch.stack(list_ls, 0), fires
def cif_v1(hidden, alphas, threshold):
def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
    batch_size, len_time = alphas.size()
    device = alphas.device
    dtype = alphas.dtype
    device = hidden.device
    dtype = hidden.dtype
    batch_size, len_time, hidden_size = hidden.size()
    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
    prefix_sum = torch.cumsum(alphas, dim=1)
@@ -682,7 +681,19 @@
    fire_idxs = dislocation_diff > 0
    fires[fire_idxs] = 1
    fires = fires + prefix_sum - prefix_sum_floor
    if return_fire_idxs:
        return fires, fire_idxs
    return fires
def cif_v1(hidden, alphas, threshold):
    fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)
    device = hidden.device
    dtype = hidden.dtype
    batch_size, len_time, hidden_size = hidden.size()
    frames = torch.zeros(batch_size, len_time, hidden_size,
                         dtype=dtype, device=device)
    prefix_sum_hidden = torch.cumsum(
        alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
    )
@@ -698,7 +709,8 @@
    remains = fires - torch.floor(fires)
    remain_frames = (
        remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
        remains[fire_idxs].unsqueeze(-1).tile((1,
                                               hidden_size)) * hidden[fire_idxs]
    )
    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)