Merge pull request #4 from TeaPoly/main
Fix some issues to make batch inference easier for the predictor and decoder.
    predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
    pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
    pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)], device=pre_acoustic_embeds.device)
    pre_token_length = pre_token_length.long()
    decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
    decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
| | |
    b, t, d = hidden.size()
    tail_threshold = self.tail_threshold
    tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
    tail_threshold = torch.reshape(tail_threshold, (1, 1))
    tail_threshold = tail_threshold.unsqueeze(0).repeat(b, 1)
    alphas = torch.cat([alphas, tail_threshold], dim=1)
    zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
    hidden = torch.cat([hidden, zeros], dim=1)