游雁
2023-04-21 3cd3473bf7a3b41484baa86d9092248d78e7af39
funasr/bin/tp_inference.py
@@ -18,7 +18,7 @@
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.datasets.preprocessor import LMPreprocessor
from funasr.tasks.asr import ASRTaskAligner_temp as ASRTask
from funasr.tasks.asr import ASRTaskAligner as ASRTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
@@ -54,7 +54,7 @@
        assert check_argument_types()
        # 1. Build ASR model
        tp_model, tp_train_args = ASRTask.build_model_from_file(
            timestamp_infer_config, timestamp_model_file, device
            timestamp_infer_config, timestamp_model_file, device=device
        )
        if 'cuda' in device:
            tp_model = tp_model.cuda()  # force model to cuda
@@ -116,8 +116,8 @@
            enc = enc[0]
        # c. Forward Predictor
        _, _, us_alphas, us_cif_peak = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1)
        return us_alphas, us_cif_peak
        _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1)
        return us_alphas, us_peaks
def inference(
@@ -179,6 +179,9 @@
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1: