| | |
| | | self, |
| | | model_dir: Union[str, Path] = None, |
| | | batch_size: int = 1, |
| | | device_id: Union[str, int] = "-1", |
| | | plot_timestamp_to: str = "", |
| | | quantize: bool = False, |
| | | intra_op_num_threads: int = 4, |
| | | cache_dir: str = None, |
| | | **kwargs, |
| | | ): |
| | | |
| | | self.device = kwargs.get("device", "cpu") |
| | | if not Path(model_dir).exists(): |
| | | try: |
| | | from modelscope.hub.snapshot_download import snapshot_download |
| | |
| | | end_idx = min(waveform_nums, beg_idx + self.batch_size) |
| | | feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx]) |
| | | ctc_logits, encoder_out_lens = self.ort_infer( |
| | | torch.Tensor(feats), |
| | | torch.Tensor(feats_len), |
| | | torch.tensor([language]), |
| | | torch.tensor([textnorm]), |
| | | torch.Tensor(feats).to(self.device), |
| | | torch.Tensor(feats_len).to(self.device), |
| | | torch.tensor([language]).to(self.device), |
| | | torch.tensor([textnorm]).to(self.device), |
| | | ) |
| | | # NOTE: only batch_size=1 is currently supported, so only the first item of the batch is decoded below |
| | | x = ctc_logits[0, : encoder_out_lens[0].item(), :] |