| | |
| | | return torch.stack(list_ls, 0), fires
|
| | |
|
| | |
|
| | | def cif_v1(hidden, alphas, threshold):
|
| | | def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
|
| | | batch_size, len_time = alphas.size()
|
| | | device = alphas.device
|
| | | dtype = alphas.dtype
|
| | |
|
| | | device = hidden.device
|
| | | dtype = hidden.dtype
|
| | | batch_size, len_time, hidden_size = hidden.size()
|
| | | threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
|
| | |
|
| | | frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
|
| | | fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
|
| | |
|
| | | prefix_sum = torch.cumsum(alphas, dim=1)
|
| | |
| | | fire_idxs = dislocation_diff > 0
|
| | | fires[fire_idxs] = 1
|
| | | fires = fires + prefix_sum - prefix_sum_floor
|
| | | if return_fire_idxs:
|
| | | return fires, fire_idxs
|
| | | return fires
|
| | |
|
| | |
|
| | | def cif_v1(hidden, alphas, threshold):
|
| | | fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)
|
| | |
|
| | | device = hidden.device
|
| | | dtype = hidden.dtype
|
| | | batch_size, len_time, hidden_size = hidden.size()
|
| | | frames = torch.zeros(batch_size, len_time, hidden_size,
|
| | | dtype=dtype, device=device)
|
| | | prefix_sum_hidden = torch.cumsum(
|
| | | alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
|
| | | )
|
| | |
| | |
|
| | | remains = fires - torch.floor(fires)
|
| | | remain_frames = (
|
| | | remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
| | | remains[fire_idxs].unsqueeze(-1).tile((1,
|
| | | hidden_size)) * hidden[fire_idxs]
|
| | | )
|
| | |
|
| | | shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
|