| | |
| | | hidden, alphas, token_num, mask=mask
|
| | | )
|
| | |
|
| | | acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
| | | acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
|
| | |
|
| | | if target_length is None and self.tail_threshold > 0.0:
|
| | | token_num_int = torch.max(token_num).type(torch.int32).item()
|
| | |
| | | hidden, alphas, token_num, mask=None
|
| | | )
|
| | |
|
| | | acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
| | | acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
|
| | | if target_length is None and self.tail_threshold > 0.0:
|
| | | token_num_int = torch.max(token_num).type(torch.int32).item()
|
| | | acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
|
| | |
| | | mask = mask.transpose(-1, -2).float()
|
| | | mask = mask.squeeze(-1)
|
| | | hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
|
| | | acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
|
| | | acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold)
|
| | |
|
| | | return acoustic_embeds, token_num, alphas, cif_peak
|
| | |
|
| | |
| | | token_num_floor = torch.floor(token_num)
|
| | |
|
| | | return hidden, alphas, token_num_floor
|
@torch.jit.script
def cif_v1_export(hidden, alphas, threshold: float):
    """Vectorised continuous integrate-and-fire (CIF), compiled with TorchScript.

    A time step "fires" when the running sum of ``alphas`` along dim 1 crosses
    an integer boundary.  For every fire, the alpha-weighted sum of ``hidden``
    accumulated since the previous fire of the same batch row is emitted as
    one output frame.

    Args:
        hidden: Tensor unpacked as ``(batch, time, hidden_size)``.
        alphas: Per-step weights, used as a ``(batch, time)`` tensor.
        threshold: Kept for interface compatibility only.  It is not read:
            fires are located with ``torch.floor``, i.e. at integer crossings
            of the running sum.

    Returns:
        A tuple ``(frame_fires, fires)``.  ``frame_fires`` has shape
        ``(batch, max fires in any row, hidden_size)`` and is zero-padded for
        rows with fewer fires.  ``fires`` has shape ``(batch, time)`` and
        holds the fractional part of the running sum, plus 1 on firing steps.
    """
    device = hidden.device
    dtype = hidden.dtype
    batch_size, len_time, hidden_size = hidden.size()

    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)

    # Compare floor(running sum) with the same quantity one step earlier;
    # a positive difference means an integer boundary was crossed.
    prefix_sum = torch.cumsum(alphas, dim=1)
    prefix_sum_floor = torch.floor(prefix_sum)
    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)

    # torch.roll wraps the last step into column 0; nothing precedes step 0.
    dislocation_prefix_sum_floor[:, 0] = 0
    dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor

    fire_idxs = dislocation_diff > 0
    fires[fire_idxs] = 1
    fires = fires + prefix_sum - prefix_sum_floor

    # Running alpha-weighted sum of hidden states along time.
    prefix_sum_hidden = torch.cumsum(
        alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
    )

    # Boolean indexing flattens the batch: row i of `frames` is the running
    # sum at the i-th fire, counted across all batch rows in order.
    frames = prefix_sum_hidden[fire_idxs]
    shift_frames = torch.roll(frames, 1, dims=0)

    # Offsets (into the flattened fire list) of the first fire of each row.
    batch_len = fire_idxs.sum(1)
    batch_idxs = torch.cumsum(batch_len, dim=0)
    shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
    shift_batch_idxs[0] = 0
    # Rows with no fire at the end of the batch (or a batch with no fire at
    # all) yield an offset equal to the number of fired frames, which would
    # index one past the end.  Such rows own no frame, so drop the offset.
    shift_batch_idxs = shift_batch_idxs[shift_batch_idxs < frames.size(0)]
    # The first fire of a row has no predecessor within that row.
    shift_frames[shift_batch_idxs] = 0

    # Part of the firing step's weight that spills over the integer boundary
    # and therefore belongs to the next output frame.
    remains = fires - torch.floor(fires)
    remain_frames = (
        remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
    )

    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
    shift_remain_frames[shift_batch_idxs] = 0

    frames = frames - shift_frames + shift_remain_frames - remain_frames

    max_label_len = batch_len.max()

    # Scatter the flattened frames back into a zero-padded per-row layout.
    frame_fires = torch.zeros(
        batch_size, max_label_len, hidden_size, dtype=dtype, device=device
    )
    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
    frame_fires_idxs = indices < batch_len.unsqueeze(1)
    frame_fires[frame_fires_idxs] = frames
    return frame_fires, fires
|
| | |
|
| | | @torch.jit.script
|
| | | def cif_export(hidden, alphas, threshold: float):
|
| | |
| | | return torch.stack(list_ls, 0), fires
|
| | |
|
| | |
|
def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
    """Compute CIF fire values from ``alphas`` alone (no hidden states).

    A step fires when the running sum of ``alphas`` along dim 1 crosses an
    integer boundary.

    Args:
        alphas: 2-D tensor unpacked as ``(batch, time)``.
        threshold: Converted to a tensor but never read afterwards; the
            firing rule relies on ``torch.floor`` only.
        return_fire_idxs: When true, also return the boolean fire mask.

    Returns:
        ``fires`` of shape ``(batch, time)``: the fractional part of the
        running sum, plus 1 on firing steps.  With ``return_fire_idxs`` set,
        the tuple ``(fires, fire_idxs)`` where ``fire_idxs`` marks the
        firing steps.
    """
    n_batch, n_steps = alphas.size()

    # Retained from the original implementation; the result is unused below.
    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)

    running = torch.cumsum(alphas, dim=1)
    running_floor = torch.floor(running)

    # Floor of the running sum one step earlier.  torch.roll wraps the last
    # step into column 0, so that column is reset: nothing precedes step 0.
    previous_floor = torch.floor(torch.roll(running, 1, dims=1))
    previous_floor[:, 0] = 0

    fire_idxs = (running_floor - previous_floor) > 0

    fires = torch.zeros(n_batch, n_steps, dtype=alphas.dtype, device=alphas.device)
    fires[fire_idxs] = 1
    fires = fires + running - running_floor

    return (fires, fire_idxs) if return_fire_idxs else fires
|
| | |
|
| | |
|
def cif_v1(hidden, alphas, threshold):
    """Vectorised continuous integrate-and-fire (CIF) over hidden states.

    Fire positions come from ``cif_wo_hidden_v1``.  For every fire, the
    alpha-weighted sum of ``hidden`` accumulated since the previous fire of
    the same batch row is emitted as one output frame.

    Args:
        hidden: Tensor unpacked as ``(batch, time, hidden_size)``.
        alphas: Per-step weights, used as a ``(batch, time)`` tensor.
        threshold: Forwarded to ``cif_wo_hidden_v1``; not used otherwise.

    Returns:
        A tuple ``(frame_fires, fires)``.  ``frame_fires`` has shape
        ``(batch, max fires in any row, hidden_size)`` and is zero-padded for
        rows with fewer fires.  ``fires`` is the per-step fire tensor returned
        by ``cif_wo_hidden_v1``.
    """
    fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)

    device = hidden.device
    dtype = hidden.dtype
    batch_size, _, hidden_size = hidden.size()

    # Running alpha-weighted sum of hidden states along time.
    prefix_sum_hidden = torch.cumsum(
        alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
    )

    # Boolean indexing flattens the batch: row i of `frames` is the running
    # sum at the i-th fire, counted across all batch rows in order.
    frames = prefix_sum_hidden[fire_idxs]
    shift_frames = torch.roll(frames, 1, dims=0)

    # Offsets (into the flattened fire list) of the first fire of each row.
    batch_len = fire_idxs.sum(1)
    batch_idxs = torch.cumsum(batch_len, dim=0)
    shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
    shift_batch_idxs[0] = 0
    # Rows with no fire at the end of the batch (or a batch with no fire at
    # all) yield an offset equal to the number of fired frames, which would
    # index one past the end.  Such rows own no frame, so drop the offset.
    shift_batch_idxs = shift_batch_idxs[shift_batch_idxs < frames.size(0)]
    # The first fire of a row has no predecessor within that row.
    shift_frames[shift_batch_idxs] = 0

    # Part of the firing step's weight that spills over the integer boundary
    # and therefore belongs to the next output frame.
    remains = fires - torch.floor(fires)
    remain_frames = (
        remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
    )

    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
    shift_remain_frames[shift_batch_idxs] = 0

    frames = frames - shift_frames + shift_remain_frames - remain_frames

    max_label_len = batch_len.max()

    # Scatter the flattened frames back into a zero-padded per-row layout.
    frame_fires = torch.zeros(
        batch_size, max_label_len, hidden_size, dtype=dtype, device=device
    )
    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
    frame_fires_idxs = indices < batch_len.unsqueeze(1)
    frame_fires[frame_fires_idxs] = frames
    return frame_fires, fires
|
| | |
|
| | |
|
| | | def cif_wo_hidden(alphas, threshold):
|
| | | batch_size, len_time = alphas.size()
|
| | |
|