| | |
| | | assert check_argument_types() |
| | | # 1. Build ASR model |
| | | tp_model, tp_train_args = ASRTask.build_model_from_file( |
| | | timestamp_infer_config, timestamp_model_file, device |
| | | timestamp_infer_config, timestamp_model_file, device=device |
| | | ) |
| | | if 'cuda' in device: |
| | | tp_model = tp_model.cuda() # force model to cuda |
| | |
| | | enc = enc[0] |
| | | |
| | | # c. Forward Predictor |
| | | _, _, us_alphas, us_cif_peak = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1) |
| | | return us_alphas, us_cif_peak |
| | | _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1) |
| | | return us_alphas, us_peaks |
| | | |
| | | |
| | | def inference( |
| | |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | | ncpu = kwargs.get("ncpu", 1) |
| | | torch.set_num_threads(ncpu) |
| | | |
| | | if batch_size > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | if ngpu > 1: |