| | |
| | | hidden, alphas, token_num, mask=None
|
| | | )
|
| | |
|
| | | acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
| | | acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
|
| | | if target_length is None and self.tail_threshold > 0.0:
|
| | | token_num_int = torch.max(token_num).type(torch.int32).item()
|
| | | acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
|
| | |
| | | mask = mask.transpose(-1, -2).float()
|
| | | mask = mask.squeeze(-1)
|
| | | hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
|
| | | acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
|
| | | acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold)
|
| | |
|
| | | return acoustic_embeds, token_num, alphas, cif_peak
|
| | |
|
| | |
| | | fires = fires + prefix_sum - prefix_sum_floor
|
| | |
|
| | | # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
|
| | | prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
|
| | | prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1)
|
| | | frames = prefix_sum_hidden[fire_idxs]
|
| | | shift_frames = torch.roll(frames, 1, dims=0)
|
| | |
|
| | |
| | |
|
| | | remains = fires - torch.floor(fires)
|
| | | # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
| | | remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
| | | remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs]
|
| | |
|
| | | shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
|
| | | shift_remain_frames[shift_batch_idxs] = 0
|
| | |
| | | # frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
|
| | | # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
|
| | | frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
|
| | | prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
|
| | | prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1)
|
| | |
|
| | | frames = prefix_sum_hidden[fire_idxs]
|
| | | shift_frames = torch.roll(frames, 1, dims=0)
|
| | |
| | |
|
| | | remains = fires - torch.floor(fires)
|
| | | # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
| | | remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
| | | remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs]
|
| | |
|
| | | shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
|
| | | shift_remain_frames[shift_batch_idxs] = 0
|