| | |
# --- CIF predictor, tail of the forward pass (fragment) ---
# NOTE(review): the enclosing `def` and its parameters are outside this chunk;
# `hidden`, `alphas`, `mask` must be bound earlier. The `| | |` gutters are
# extraction artifacts, not code. Comments below describe only what these
# visible lines show.
#
# Thresholded ReLU: scale the raw firing weights and clip anything at or
# below the noise floor to zero (alphas <= noise_threshold/smooth_factor
# are silenced and will never fire).
| | | alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
|
# Re-orient the padding mask to match alphas' layout and make it float so it
# can be multiplied in. Assumes mask arrives with the time axis last
# (e.g. (B, 1, T)) -- TODO confirm against the caller.
| | | mask = mask.transpose(-1, -2).float()
|
# Zero the firing weights at padded positions.
| | | alphas = alphas * mask
|
| | | |
# Drop the trailing singleton channel axis from the weights.
| | | alphas = alphas.squeeze(-1)
|
| | | |
# Expected number of emitted tokens per utterance = sum of firing weights
# over time (a float estimate, not yet an integer count).
| | | token_num = alphas.sum(-1)
|
| | |
|
| | | mask = mask.squeeze(-1)
|
# Append the "tail" frame / tail threshold so residual accumulated weight at
# the end of the utterance can still trigger a final firing; token_num is
# recomputed inside (see tail_process_fn).
| | | hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
|
# Continuous Integrate-and-Fire: accumulate alphas frame by frame and emit a
# weighted acoustic embedding each time the running sum crosses
# self.threshold; cif_peak marks the firing positions.
| | | acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
| | |
|
| | | return acoustic_embeds, token_num, alphas, cif_peak
|
| | |
| | |
# --- tail_process_fn body (fragment): pad the sequence with one "tail"
# position carrying tail_threshold, so that leftover accumulated weight at
# the end of an utterance still produces a final firing in CIF. ---
# NOTE(review): the enclosing `def` (and the later `return`) are outside this
# chunk; `b`, `d`, `mask`, and `tail_threshold` must be bound earlier in the
# function. The `| | |` gutters are extraction artifacts, not code.
|
# One column of zeros / ones per batch element, used both for padding and
# for the shift trick below.
| | | zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
|
| | | ones_t = torch.ones_like(zeros_t)
|
| | |
|
# Shift trick: with mask = 1 on valid frames and 0 on padding,
# mask_2 - mask_1 is 1 exactly at the first padded position after each
# sequence (one past the last valid frame) and 0 everywhere else.
| | | mask_1 = torch.cat([mask, zeros_t], dim=1)
|
| | | mask_2 = torch.cat([ones_t, mask], dim=1)
|
| | | mask = mask_2 - mask_1
|
# Place the tail threshold only at that one-past-the-end position.
| | | tail_threshold = mask * tail_threshold
|
# NOTE(review): the next line and the cat/add pair below look like BOTH arms
# of an upstream `if mask is not None: ... else: ...` (as in FunASR's
# CifPredictorV2.tail_process_fn) flattened into straight-line code.
# Executing all three would append two extra time steps and apply the tail
# threshold twice -- confirm against the original and restore the branch.
| | | alphas = torch.cat([alphas, tail_threshold], dim=1)
|
| | | |
# Extend alphas by one zero column, then add the (position-selective) tail
# threshold so only the one-past-the-end slot receives it.
| | | alphas = torch.cat([alphas, zeros_t], dim=1)
|
| | | alphas = torch.add(alphas, tail_threshold)
|
| | |
|
# Append a matching zero frame to hidden so hidden and alphas keep the same
# time length after the tail column was added.
| | | zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
|
| | | hidden = torch.cat([hidden, zeros], dim=1)
|
# Recompute the expected token count now that the tail weight is included.
| | | token_num = alphas.sum(dim=-1)
|