| | |
| | | def forward(self, hidden: torch.Tensor,
|
| | | mask: torch.Tensor,
|
| | | ):
|
| | | alphas, token_num = self.forward_cnn(hidden, mask)
|
| | | mask = mask.transpose(-1, -2).float()
|
| | | mask = mask.squeeze(-1)
|
| | | hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
|
| | | acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
| | | |
| | | return acoustic_embeds, token_num, alphas, cif_peak
|
| | | |
| | | def forward_cnn(self, hidden: torch.Tensor,
|
| | | mask: torch.Tensor,
|
| | | ):
|
| | | h = hidden
|
| | | context = h.transpose(1, 2)
|
| | | queries = self.pad(context)
|
| | |
| | | alphas = alphas * mask
|
| | | alphas = alphas.squeeze(-1)
|
| | | token_num = alphas.sum(-1)
|
| | | |
| | | mask = mask.squeeze(-1)
|
| | | hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
|
| | | acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
| | | |
| | | return acoustic_embeds, token_num, alphas, cif_peak
|
| | |
|
| | | return alphas, token_num
|
| | |
|
| | | def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
|
| | | b, t, d = hidden.size()
|
| | |
| | | integrate)
|
| | |
|
| | | fires = torch.stack(list_fires, 1)
|
| | | return fires |
| | | return fires
|