hnluo
2023-03-11 7231fa07672ba276a7df937e27e2b813e653dd21
funasr/bin/tp_inference.py
@@ -110,7 +110,7 @@
            timestamp_infer_config, timestamp_model_file, device
        )
        if 'cuda' in device:
            tp_model = tp_model.cuda()
            tp_model = tp_model.cuda()  # force model to cuda
        frontend = None
        if tp_train_args.frontend is not None:
@@ -263,7 +263,7 @@
    preprocessor = LMPreprocessor(
        train=False,
        token_type=speechtext2timestamp.tp_train_args.token_type,
        token_list=speechtext2timestamp.tp_train_args,
        token_list=speechtext2timestamp.tp_train_args.token_list,
        bpemodel=None,
        text_cleaner=None,
        g2p_type=None,
@@ -293,14 +293,11 @@
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=LMPreprocessor,
            preprocess_fn=preprocessor,
            collate_fn=ASRTask.build_collate_fn(speechtext2timestamp.tp_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
        file_count = 1
        tp_result_list = []
        for keys, batch in loader:
@@ -321,7 +318,6 @@
                ts_str, ts_list = time_stamp_lfr6_advance(us_alphas[batch_id], us_cif_peak[batch_id], token)
                logging.warning(ts_str)
                item = {'key': key, 'value': ts_str, 'timestamp':ts_list}
                # tp_result_list.append({'text':"".join([i for i in token if i != '<sil>']), 'timestamp': ts_list})
                tp_result_list.append(item)
        return tp_result_list
@@ -407,6 +403,18 @@
        default=1,
        help="The batch size for inference",
    )
    group.add_argument(
        "--seg_dict_file",
        type=str,
        default=None,
        help="The batch size for inference",
    )
    group.add_argument(
        "--split_with_space",
        type=bool,
        default=False,
        help="The batch size for inference",
    )
    return parser