| | |
| | | import time |
| | | import copy |
| | | import os |
| | | import re |
| | | import codecs |
| | | import tempfile |
| | | import requests |
| | |
| | | |
| | | nbest_hyps = nbest_hyps[: self.nbest] |
| | | else: |
| | | yseq = am_scores.argmax(dim=-1) |
| | | score = am_scores.max(dim=-1)[0] |
| | | score = torch.sum(score, dim=-1) |
| | | # pad with mask tokens to ensure compatibility with sos/eos tokens |
| | | yseq = torch.tensor( |
| | | [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device |
| | | ) |
| | | if pre_token_length[i] == 0: |
| | | yseq = torch.tensor( |
| | | [self.asr_model.sos] + [self.asr_model.eos], device=yseq.device |
| | | ) |
| | | score = torch.tensor(0.0, device=yseq.device) |
| | | else: |
| | | yseq = am_scores.argmax(dim=-1) |
| | | score = am_scores.max(dim=-1)[0] |
| | | score = torch.sum(score, dim=-1) |
| | | # pad with mask tokens to ensure compatibility with sos/eos tokens |
| | | yseq = torch.tensor( |
| | | [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device |
| | | ) |
| | | nbest_hyps = [Hypothesis(yseq=yseq, score=score)] |
| | | |
| | | for hyp in nbest_hyps: |
| | | assert isinstance(hyp, (Hypothesis)), type(hyp) |
| | | |
| | |
| | | |
| | | # Change integer-ids to tokens |
| | | token = self.converter.ids2tokens(token_int) |
| | | token = " ".join(token) |
| | | |
| | | results.append(token) |
| | | postprocessed_result = "" |
| | | for item in token: |
| | | if item.endswith('@@'): |
| | | postprocessed_result += item[:-2] |
| | | elif re.match('^[a-zA-Z]+$', item): |
| | | postprocessed_result += item + " " |
| | | else: |
| | | postprocessed_result += item |
| | | |
| | | results.append(postprocessed_result) |
| | | |
| | | # assert check_return_type(results) |
| | | return results |