| | |
| | | |
| | | nbest_hyps = nbest_hyps[: self.nbest] |
| | | else: |
| | | yseq = am_scores.argmax(dim=-1) |
| | | score = am_scores.max(dim=-1)[0] |
| | | score = torch.sum(score, dim=-1) |
| | | # pad with mask tokens to ensure compatibility with sos/eos tokens |
| | | yseq = torch.tensor( |
| | | [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device |
| | | ) |
| | | if pre_token_length[i] == 0: |
| | | yseq = torch.tensor( |
| | | [self.asr_model.sos] + [self.asr_model.eos], device=yseq.device |
| | | ) |
| | | score = torch.tensor(0.0, device=yseq.device) |
| | | else: |
| | | yseq = am_scores.argmax(dim=-1) |
| | | score = am_scores.max(dim=-1)[0] |
| | | score = torch.sum(score, dim=-1) |
| | | # pad with mask tokens to ensure compatibility with sos/eos tokens |
| | | yseq = torch.tensor( |
| | | [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device |
| | | ) |
| | | nbest_hyps = [Hypothesis(yseq=yseq, score=score)] |
| | | |
| | | for hyp in nbest_hyps: |
| | | assert isinstance(hyp, (Hypothesis)), type(hyp) |
| | | |
| | |
| | | feats = cache_en["feats"] |
| | | feats_len = torch.tensor([feats.shape[1]]) |
| | | self.asr_model.frontend = None |
| | | self.frontend.cache_reset() |
| | | results = self.infer(feats, feats_len, cache) |
| | | return results |
| | | else: |
| | | if self.frontend is not None: |
| | | if cache_en["start_idx"] == 0: |
| | | self.frontend.cache_reset() |
| | | feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"]) |
| | | feats = to_device(feats, device=self.device) |
| | | feats_len = feats_len.int() |
| | |
| | | d = ModelDownloader() |
| | | kwargs.update(**d.download_and_unpack(model_tag)) |
| | | |
| | | return Speech2Text(**kwargs) |
| | | return Speech2TextTransducer(**kwargs) |
| | | |
| | | |
| | | class Speech2TextSAASR: |