游雁
2023-07-03 4ee715e70e36cdba7b05fe044fecab9cf4fa16ff
funasr/bin/punc_inference_launch.py
@@ -1,19 +1,169 @@
#!/usr/bin/env python3
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
import os
import sys
from typing import Union, Dict, Any
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Union
import torch
from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.types import float_or_none
def inference_punc(
        batch_size: int,
        dtype: str,
        ngpu: int,
        seed: int,
        num_workers: int,
        log_level: Union[int, str],
        key_file: Optional[str],
        train_config: Optional[str],
        model_file: Optional[str],
        output_dir: Optional[str] = None,
        param_dict: dict = None,
        **kwargs,
):
    """Build an offline punctuation-restoration pipeline.

    Loads a ``Text2Punc`` model once and returns a ``_forward`` closure that
    punctuates either a single raw text or every ``key<TAB>text`` line of the
    given input files.

    Args:
        batch_size, dtype, num_workers, key_file, param_dict: accepted for
            interface compatibility with the other launchers; unused here.
        ngpu: CUDA is used when this is >= 1 and CUDA is available,
            otherwise CPU.
        seed: passed to ``set_all_random_seed``.
        log_level: level passed to ``logging.basicConfig``.
        train_config: config path forwarded to ``Text2Punc``.
        model_file: model path forwarded to ``Text2Punc``.
        output_dir: default directory for ``infer.out`` when ``_forward`` is
            not given ``output_dir_v2``.
        **kwargs: ignored.

    Returns:
        The ``_forward`` callable.
    """
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    device = "cuda" if ngpu >= 1 and torch.cuda.is_available() else "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)
    text2punc = Text2Punc(train_config, model_file, device)

    def _forward(
            data_path_and_name_and_type,
            raw_inputs: Union[List[Any], bytes, str] = None,
            output_dir_v2: Optional[str] = None,
            cache: List[Any] = None,
            param_dict: dict = None,
    ):
        """Punctuate ``raw_inputs`` or the contents of the given text files.

        Args:
            data_path_and_name_and_type: iterable of ``(path, name, type)``
                triples; only ``path`` is read. Each line must be
                ``key<TAB>text``. Lines without exactly one tab, or with an
                empty text, are skipped. May be None when ``raw_inputs`` is
                given.
            raw_inputs: a single text (``str``, or ``bytes`` decoded as
                UTF-8). When given, it is punctuated under key ``"demo"`` and
                returned immediately; nothing is written to disk.
            output_dir_v2: overrides ``output_dir`` as the directory for
                ``infer.out``.
            cache, param_dict: unused; kept for pipeline interface
                compatibility.

        Returns:
            List of ``{'key': ..., 'value': ...}`` dicts.

        Raises:
            ValueError: if neither ``raw_inputs`` nor
                ``data_path_and_name_and_type`` is provided.
        """
        results = []
        if raw_inputs is not None:
            if isinstance(raw_inputs, bytes):
                # NOTE(review): UTF-8 is assumed for byte input, matching the
                # encoding used when reading input files below.
                raw_inputs = raw_inputs.decode("utf-8")
            line = raw_inputs.strip()
            key = "demo"
            if line == "":
                results.append({'key': key, 'value': ""})
                return results
            result, _ = text2punc(line)
            results.append({'key': key, 'value': result})
            return results

        if data_path_and_name_and_type is None:
            raise ValueError(
                "inference_punc: either raw_inputs or "
                "data_path_and_name_and_type must be provided"
            )
        for inference_text, _, _ in data_path_and_name_and_type:
            with open(inference_text, "r", encoding="utf-8") as fin:
                for line in fin:
                    segs = line.strip().split("\t")
                    # Expect exactly "key<TAB>text"; skip malformed or empty rows.
                    if len(segs) != 2 or len(segs[1]) == 0:
                        continue
                    key, text = segs
                    result, _ = text2punc(text)
                    results.append({'key': key, 'value': result})

        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            Path(output_path).mkdir(parents=True, exist_ok=True)
            output_file_path = (Path(output_path) / "infer.out").absolute()
            with open(output_file_path, "w", encoding="utf-8") as fout:
                fout.writelines(
                    f"{item_i['key']}\t{item_i['value']}\n" for item_i in results
                )
        return results

    return _forward
def inference_punc_vad_realtime(
        batch_size: int,
        dtype: str,
        ngpu: int,
        seed: int,
        num_workers: int,
        log_level: Union[int, str],
        key_file: Optional[str],
        train_config: Optional[str],
        model_file: Optional[str],
        output_dir: Optional[str] = None,
        param_dict: dict = None,
        **kwargs,
):
    """Build a streaming (VAD-realtime) punctuation pipeline.

    Loads a ``Text2PuncVADRealtime`` model once and returns a ``_forward``
    closure that punctuates one raw text chunk per call. Streaming state is
    carried between calls through ``param_dict["cache"]``.

    Args:
        batch_size, dtype, num_workers, log_level, key_file, output_dir,
        param_dict: accepted for interface compatibility; unused here.
        ngpu: CUDA is used when this is >= 1 and CUDA is available,
            otherwise CPU.
        seed: passed to ``set_all_random_seed``.
        train_config: config path forwarded to ``Text2PuncVADRealtime``.
        model_file: model path forwarded to ``Text2PuncVADRealtime``.
        **kwargs: ``ncpu`` (default 1) is passed to ``torch.set_num_threads``;
            other keys are ignored.

    Returns:
        The ``_forward`` callable.
    """
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    device = "cuda" if ngpu >= 1 and torch.cuda.is_available() else "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)
    text2punc = Text2PuncVADRealtime(train_config, model_file, device)

    def _forward(
            data_path_and_name_and_type,
            raw_inputs: Union[List[Any], bytes, str] = None,
            output_dir_v2: Optional[str] = None,
            cache: List[Any] = None,
            param_dict: dict = None,
    ):
        """Punctuate one text chunk, updating ``param_dict["cache"]`` in place.

        Args:
            data_path_and_name_and_type, output_dir_v2, cache: unused. File
                input is not supported in this mode.
            raw_inputs: the text chunk (``str``, or ``bytes`` decoded as
                UTF-8). If None, an empty list is returned.
            param_dict: must contain ``"cache"`` for any non-empty chunk. The
                entry is passed to the model and replaced by the cache the
                model returns.

        Returns:
            ``[]`` when ``raw_inputs`` is None; otherwise a one-element list
            ``[{'key': 'demo', 'value': <punctuated text>}]``.

        Raises:
            ValueError: if a non-empty chunk arrives without
                ``param_dict["cache"]``.
        """
        results = []
        if raw_inputs is None:
            return results
        if isinstance(raw_inputs, bytes):
            # NOTE(review): UTF-8 is assumed for byte input.
            raw_inputs = raw_inputs.decode("utf-8")
        line = raw_inputs.strip()
        key = "demo"
        if line == "":
            results.append({'key': key, 'value': ""})
            return results
        # Previously this was read unconditionally, so the default
        # param_dict=None crashed every call with an opaque TypeError.
        if param_dict is None or "cache" not in param_dict:
            raise ValueError(
                "punc_VadRealtime mode requires param_dict with a 'cache' "
                "entry to carry state between calls"
            )
        result, _, new_cache = text2punc(line, param_dict["cache"])
        param_dict["cache"] = new_cache
        results.append({'key': key, 'value': result})
        return results

    return _forward
def inference_launch(mode, **kwargs):
    """Build the inference pipeline for the given punctuation ``mode``.

    Args:
        mode: ``"punc"`` for offline punctuation, or ``"punc_VadRealtime"``
            for streaming punctuation.
        **kwargs: forwarded unchanged to the selected launcher.

    Returns:
        The pipeline callable from the selected launcher, or None (after
        logging an error) for an unknown mode.
    """
    if mode == "punc":
        return inference_punc(**kwargs)
    elif mode == "punc_VadRealtime":
        return inference_punc_vad_realtime(**kwargs)
    else:
        # ERROR rather than INFO: callers go on to invoke the returned
        # pipeline, so this is a hard failure that must not be hidden.
        logging.error("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
@@ -59,33 +209,16 @@
    )
    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        action="append",
        required=False
    )
    group.add_argument(
        "--raw_inputs",
        type=str,
        required=False
    )
    group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
    group.add_argument("--raw_inputs", type=str, required=False)
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--cache", type=list, required=False)
    group.add_argument("--param_dict", type=dict, required=False)
    group = parser.add_argument_group("The model configuration related")
    group.add_argument("--train_config", type=str)
    group.add_argument("--model_file", type=str)
    group.add_argument("--mode", type=str, default="punc")
    return parser
def inference_launch(mode, **kwargs):
    """Build the inference pipeline for the given punctuation ``mode``.

    NOTE: this redefinition shadows the earlier ``inference_launch`` in this
    module. It dispatches to the same in-module launchers so that both
    supported modes stay reachable; keep the two definitions in sync (or
    remove one).

    Args:
        mode: ``"punc"`` for offline punctuation, or ``"punc_VadRealtime"``
            for streaming punctuation.
        **kwargs: forwarded unchanged to the selected launcher.

    Returns:
        The pipeline callable from the selected launcher, or None (after
        logging an error) for an unknown mode.
    """
    if mode == "punc":
        return inference_punc(**kwargs)
    elif mode == "punc_VadRealtime":
        return inference_punc_vad_realtime(**kwargs)
    else:
        # ERROR rather than INFO: callers go on to invoke the returned
        # pipeline, so this is a hard failure that must not be hidden.
        logging.error("Unknown decoding mode: {}".format(mode))
        return None
def main(cmd=None):
@@ -111,7 +244,8 @@
    kwargs.pop("gpuid_list", None)
    kwargs.pop("njob", None)
    results = inference_launch(**kwargs)
    inference_pipeline = inference_launch(**kwargs)
    return inference_pipeline(kwargs["data_path_and_name_and_type"])
if __name__ == "__main__":