Roll back cif_v1 because of a training bug
| | |
| | | hidden, alphas, token_num, mask=mask
|
| | | )
|
| | |
|
| | | acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
|
| | | acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
| | |
|
| | | if target_length is None and self.tail_threshold > 0.0:
|
| | | token_num_int = torch.max(token_num).type(torch.int32).item()
|
| | |
| | | hidden, alphas, token_num, mask=None
|
| | | )
|
| | |
|
| | | acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
|
| | | acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
| | | if target_length is None and self.tail_threshold > 0.0:
|
| | | token_num_int = torch.max(token_num).type(torch.int32).item()
|
| | | acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
|