Fix a bug in the predictor's `tail_process_fn`
| | |
| | | mask_2 = torch.cat([ones_t, mask], dim=1)
|
| | | mask = mask_2 - mask_1
|
| | | tail_threshold = mask * tail_threshold
|
| | | alphas = torch.cat([alphas, tail_threshold], dim=1)
|
| | | alphas = torch.cat([alphas, zeros_t], dim=1)
|
| | | alphas = torch.add(alphas, tail_threshold)
|
| | | else:
|
| | | tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
|
| | | tail_threshold = torch.reshape(tail_threshold, (1, 1))
|
| | |
| | |
|
| | | predictor_alignments = index_div_bool_zeros_count_tile_out
|
| | | predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
|
| | | return predictor_alignments.detach(), predictor_alignments_length.detach()
|
| | | return predictor_alignments.detach(), predictor_alignments_length.detach()
|