| | |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.types import float_or_none |
| | | |
| | | import argparse |
| | | import logging |
| | | from pathlib import Path |
| | | import sys |
| | | from typing import Optional |
| | | from typing import Sequence |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import Any |
| | | from typing import List |
| | | |
| | | import numpy as np |
| | | import torch |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.tasks.punctuation import PunctuationTask |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.forward_adaptor import ForwardAdaptor |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.datasets.preprocessor import split_to_mini_sentence |
| | | from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime |
| | | |
def inference_punc(
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
):
    """Build a callable that adds punctuation to text with ``Text2Punc``.

    The model is constructed once here; the returned ``_forward`` closure
    reuses it for every call.

    Args:
        batch_size: Accepted for interface compatibility; not used here.
        dtype: Accepted for interface compatibility; not used here.
        ngpu: ``"cuda"`` is selected when this is >= 1 and CUDA is available,
            otherwise ``"cpu"``.
        seed: Passed to ``set_all_random_seed``.
        num_workers: Accepted for interface compatibility; not used here.
        log_level: Level passed to ``logging.basicConfig``.
        key_file: Accepted for interface compatibility; not used here.
        train_config: First argument passed to ``Text2Punc``.
        model_file: Second argument passed to ``Text2Punc``.
        output_dir: Default directory for ``infer.out`` when ``_forward`` is
            called without ``output_dir_v2``.
        param_dict: Accepted for interface compatibility; not used here.
        **kwargs: Ignored.

    Returns:
        The ``_forward`` function described below.
    """
    assert check_argument_types()
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    device = "cuda" if ngpu >= 1 and torch.cuda.is_available() else "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)
    text2punc = Text2Punc(train_config, model_file, device)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[List[Any], bytes, str] = None,
        output_dir_v2: Optional[str] = None,
        cache: List[Any] = None,
        param_dict: dict = None,
    ):
        """Punctuate either a raw string or the lines of tab-separated files.

        Args:
            data_path_and_name_and_type: Iterable of 3-tuples whose first
                element is the path of a UTF-8 text file; only consulted when
                ``raw_inputs`` is None. Each usable line is ``key<TAB>text``.
            raw_inputs: When not None, it is stripped and punctuated directly
                under the key ``"demo"``; files are not read and nothing is
                written to disk.
            output_dir_v2: Overrides the outer ``output_dir`` for this call.
            cache: Unused.
            param_dict: Unused.

        Returns:
            A list of ``{"key": ..., "value": ...}`` dicts.
        """
        if raw_inputs is not None:
            line = raw_inputs.strip()
            if line == "":
                # Nothing to punctuate; the model is not invoked.
                return [{'key': "demo", 'value': ""}]
            result, _ = text2punc(line)
            return [{'key': "demo", 'value': result}]

        results = []
        for inference_text, _, _ in data_path_and_name_and_type:
            with open(inference_text, "r", encoding="utf-8") as fin:
                for line in fin:
                    segs = line.strip().split("\t")
                    # Lines that are not exactly "key<TAB>text" are skipped,
                    # as are lines whose text part is empty.
                    if len(segs) != 2:
                        continue
                    key, text = segs
                    if not text:
                        continue
                    result, _ = text2punc(text)
                    results.append({'key': key, 'value': result})

        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            out_dir = Path(output_path)
            out_dir.mkdir(parents=True, exist_ok=True)
            output_file_path = (out_dir / "infer.out").absolute()
            with open(output_file_path, "w", encoding="utf-8") as fout:
                for item in results:
                    fout.write(f"{item['key']}\t{item['value']}\n")
        return results

    return _forward
| | | |
def inference_punc_vad_realtime(
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    # cache: list,
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
):
    """Build a callable that punctuates text with ``Text2PuncVADRealtime``.

    The model is constructed once here; the returned ``_forward`` closure
    reuses it and carries state between calls through ``param_dict["cache"]``.

    Args:
        batch_size: Accepted for interface compatibility; not used here.
        dtype: Accepted for interface compatibility; not used here.
        ngpu: ``"cuda"`` is selected when this is >= 1 and CUDA is available,
            otherwise ``"cpu"``.
        seed: Passed to ``set_all_random_seed``.
        num_workers: Accepted for interface compatibility; not used here.
        log_level: Accepted for interface compatibility; not used here.
        key_file: Accepted for interface compatibility; not used here.
        train_config: First argument passed to ``Text2PuncVADRealtime``.
        model_file: Second argument passed to ``Text2PuncVADRealtime``.
        output_dir: Accepted for interface compatibility; not used here.
        param_dict: Accepted for interface compatibility; not used here.
        **kwargs: Only ``ncpu`` (default 1) is read; it is passed to
            ``torch.set_num_threads``.

    Returns:
        The ``_forward`` function described below.
    """
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

    device = "cuda" if ngpu >= 1 and torch.cuda.is_available() else "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)
    text2punc = Text2PuncVADRealtime(train_config, model_file, device)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[List[Any], bytes, str] = None,
        output_dir_v2: Optional[str] = None,
        cache: List[Any] = None,
        param_dict: dict = None,
    ):
        """Punctuate one raw text chunk, updating ``param_dict["cache"]``.

        Args:
            data_path_and_name_and_type: Unused.
            raw_inputs: Text to punctuate. When None, an empty list is
                returned.
            output_dir_v2: Unused.
            cache: Unused; the cache is taken from ``param_dict`` instead.
            param_dict: Required despite the ``None`` default. Its ``"cache"``
                entry is passed to the model and then overwritten in place
                with the cache the model returns.

        Returns:
            A list holding at most one ``{"key": "demo", "value": ...}`` dict.

        Raises:
            TypeError: If ``param_dict`` is None.
            KeyError: If ``param_dict`` has no ``"cache"`` entry.
        """
        if param_dict is None:
            # Same exception type as subscripting None, but with a message
            # that tells the caller what is actually missing.
            raise TypeError("param_dict must be a dict containing a 'cache' entry")
        cache_in = param_dict["cache"]

        if raw_inputs is None:
            return []

        line = raw_inputs.strip()
        if line == "":
            # Nothing to punctuate; the model is not invoked and the cache
            # is left untouched.
            return [{'key': "demo", 'value': ""}]

        result, _, cache_out = text2punc(line, cache_in)
        # Hand the updated cache back to the caller through the shared dict.
        param_dict["cache"] = cache_out
        return [{'key': "demo", 'value': result}]

    return _forward
| | | |
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | |
| | | |
def inference_launch(mode, **kwargs):
    """Return the inference callable for the requested decoding mode.

    Args:
        mode: ``"punc"`` selects ``inference_punc``; ``"punc_VadRealtime"``
            selects ``inference_punc_vad_realtime``.
        **kwargs: Forwarded unchanged to the selected builder.

    Returns:
        Whatever the selected builder returns, or None (after logging the
        mode at INFO level) when ``mode`` is not recognised.
    """
    # Dispatch to the builders defined in this module. Each branch previously
    # had a second, unreachable ``return`` naming these same builders after a
    # delegation to another module's ``inference_modelscope``.
    if mode == "punc":
        return inference_punc(**kwargs)
    if mode == "punc_VadRealtime":
        return inference_punc_vad_realtime(**kwargs)

    logging.info("Unknown decoding mode: {}".format(mode))
    return None
| | |
| | | |
| | | kwargs.pop("gpuid_list", None) |
| | | kwargs.pop("njob", None) |
| | | results = inference_launch(**kwargs) |
| | | inference_pipeline = inference_launch(**kwargs) |
| | | return inference_pipeline(kwargs["data_path_and_name_and_type"]) |
| | | |
| | | |
| | | |
| | | if __name__ == "__main__": |