| | |
| | | token_num_floor = torch.floor(token_num)
|
| | |
|
| | | return hidden, alphas, token_num_floor
|
| | |
|
| | |
|
| | | @torch.jit.script
|
| | | def cif_v1_export(hidden, alphas, threshold: float):
|
| | | device = hidden.device
|
| | |
| | | frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
|
| | | fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
|
| | |
|
| | | prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(torch.float32) # cumsum precision degradation cause wrong result in extreme |
| | | prefix_sum = torch.cumsum(alphas, dim=1)
|
| | | prefix_sum_floor = torch.floor(prefix_sum)
|
| | | dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
|
| | | dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
|
| | |
| | | fires[fire_idxs] = 1
|
| | | fires = fires + prefix_sum - prefix_sum_floor
|
| | |
|
| | | prefix_sum_hidden = torch.cumsum(
|
| | | alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
|
| | | )
|
| | | prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
|
| | |
|
| | | frames = prefix_sum_hidden[fire_idxs]
|
| | | shift_frames = torch.roll(frames, 1, dims=0)
|
| | |
| | | shift_frames[shift_batch_idxs] = 0
|
| | |
|
| | | remains = fires - torch.floor(fires)
|
| | | remain_frames = (
|
| | | remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
| | | )
|
| | | remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
| | |
|
| | | shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
|
| | | shift_remain_frames[shift_batch_idxs] = 0
|
| | |
|
| | | frames = frames - shift_frames + shift_remain_frames - remain_frames
|
| | |
|
| | | max_label_len = alphas.sum(dim=-1)
|
| | | max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64)
|
| | | max_label_len = batch_len.max()
|
| | |
|
| | | frame_fires = torch.zeros(
|
| | | batch_size, max_label_len, hidden_size, dtype=dtype, device=device
|
| | | )
|
| | | frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
|
| | | indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
|
| | | frame_fires_idxs = indices < batch_len.unsqueeze(1)
|
| | | frame_fires[frame_fires_idxs] = frames
|
| | | return frame_fires, fires
|
| | |
|
| | |
|
| | | @torch.jit.script
|
| | | def cif_export(hidden, alphas, threshold: float):
|
| | |
| | |
|
| | | fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
|
| | |
|
| | | prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(torch.float32) # cumsum precision degradation cause wrong result in extreme |
| | | prefix_sum = torch.cumsum(alphas, dim=1)
|
| | | prefix_sum_floor = torch.floor(prefix_sum)
|
| | | dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
|
| | | dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
|
| | |
| | | device = hidden.device
|
| | | dtype = hidden.dtype
|
| | | batch_size, len_time, hidden_size = hidden.size()
|
| | | frames = torch.zeros(batch_size, len_time, hidden_size,
|
| | | dtype=dtype, device=device)
|
| | | prefix_sum_hidden = torch.cumsum(
|
| | | alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
|
| | | )
|
| | | frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
|
| | | prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
|
| | |
|
| | | frames = prefix_sum_hidden[fire_idxs]
|
| | | shift_frames = torch.roll(frames, 1, dims=0)
|
| | |
| | | shift_frames[shift_batch_idxs] = 0
|
| | |
|
| | | remains = fires - torch.floor(fires)
|
| | | remain_frames = (
|
| | | remains[fire_idxs].unsqueeze(-1).tile((1,
|
| | | hidden_size)) * hidden[fire_idxs]
|
| | | )
|
| | | remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
| | |
|
| | | shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
|
| | | shift_remain_frames[shift_batch_idxs] = 0
|
| | |
|
| | | frames = frames - shift_frames + shift_remain_frames - remain_frames
|
| | |
|
| | | max_label_len = torch.round(alphas.sum(-1)).int().max() # torch.round to calculate the max length
|
| | | max_label_len = batch_len.max()
|
| | |
|
| | | frame_fires = torch.zeros(
|
| | | batch_size, max_label_len, hidden_size, dtype=dtype, device=device
|
| | | )
|
| | | frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
|
| | | indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
|
| | | frame_fires_idxs = indices < batch_len.unsqueeze(1)
|
| | | frame_fires[frame_fires_idxs] = frames
|