feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
_language_list = language_list[beg_idx:end_idx]
_textnorm_list = textnorm_list[beg_idx:end_idx]
# fall back to the first language/textnorm id when the slice is empty
if not len(_language_list):
    _language_list = [language_list[0]]
    _textnorm_list = [textnorm_list[0]]
B = feats.shape[0]
# broadcast a single language id across the whole batch
if len(_language_list) == 1 and B != 1:
    _language_list = _language_list * B

# NOTE: the start of this call is truncated in the excerpt; the assignment target,
# callee name, and first two arguments below are assumed (a forward pass returning
# CTC logits and encoder output lengths), not taken verbatim from the source.
ctc_logits, encoder_out_lens = self.infer(
    feats.to(self.device),
    feats_len.to(self.device),
    torch.tensor(_language_list).to(self.device),
    torch.tensor(_textnorm_list).to(self.device),
)
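# ctc_logits has shape (batch, frames, vocab); encoder_out_lens holds the
# number of valid frames for each utterance in the batch.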

# if the backend returned a NumPy array, convert back to torch.Tensor once
if isinstance(ctc_logits, np.ndarray):
    ctc_logits = torch.from_numpy(ctc_logits).float()

# greedy CTC decoding for each utterance in the batch
for b in range(feats.shape[0]):
    # frame-wise argmax over the valid frames, then collapse consecutive repeats
    x = ctc_logits[b, : encoder_out_lens[b].item(), :]
    yseq = x.argmax(dim=-1)
    yseq = torch.unique_consecutive(yseq, dim=-1)

    # drop CTC blank tokens and detokenize the remaining ids
    mask = yseq != self.blank_id
    token_int = yseq[mask].tolist()

    asr_res.append(self.tokenizer.decode(token_int))

return asr_res
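
# Illustrative sketch: a minimal, self-contained example of the same greedy CTC
# collapse used above (frame-wise argmax, merge consecutive repeats, drop blanks).
# The function name and the blank id of 0 are assumptions for this toy example only.
import torch

def greedy_ctc_collapse(logits: torch.Tensor, blank_id: int = 0) -> list:
    """Collapse (frames, vocab) logits into a token-id list, CTC-style."""
    yseq = logits.argmax(dim=-1)                   # best token per frame
    yseq = torch.unique_consecutive(yseq, dim=-1)  # merge repeated frames
    return yseq[yseq != blank_id].tolist()         # strip blank tokens

# toy check: frames predicting [blank, 1, 1, blank, 2] collapse to [1, 2]
toy = torch.tensor([
    [5.0, 0.0, 0.0],  # blank
    [0.0, 5.0, 0.0],  # token 1
    [0.0, 5.0, 0.0],  # token 1 (repeat)
    [5.0, 0.0, 0.0],  # blank
    [0.0, 0.0, 5.0],  # token 2
])
assert greedy_ctc_collapse(toy) == [1, 2]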