jmwang66
2023-06-20 2ff405b2f4ab899eff9bece232969fbb0c8f0555
funasr/bin/punc_inference_launch.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -7,55 +7,36 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.types import float_or_none
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Any
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.punctuation import PunctuationTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.datasets.preprocessor import split_to_mini_sentence
from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
def inference_punc(
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
        batch_size: int,
        dtype: str,
        ngpu: int,
        seed: int,
        num_workers: int,
        log_level: Union[int, str],
        key_file: Optional[str],
        train_config: Optional[str],
        model_file: Optional[str],
        output_dir: Optional[str] = None,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    logging.basicConfig(
@@ -73,11 +54,11 @@
    text2punc = Text2Punc(train_config, model_file, device)
    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[List[Any], bytes, str] = None,
        output_dir_v2: Optional[str] = None,
        cache: List[Any] = None,
        param_dict: dict = None,
            data_path_and_name_and_type,
            raw_inputs: Union[List[Any], bytes, str] = None,
            output_dir_v2: Optional[str] = None,
            cache: List[Any] = None,
            param_dict: dict = None,
    ):
        results = []
        split_size = 20
@@ -121,20 +102,21 @@
    return _forward
def inference_punc_vad_realtime(
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    #cache: list,
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
        batch_size: int,
        dtype: str,
        ngpu: int,
        seed: int,
        num_workers: int,
        log_level: Union[int, str],
        # cache: list,
        key_file: Optional[str],
        train_config: Optional[str],
        model_file: Optional[str],
        output_dir: Optional[str] = None,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
@@ -150,11 +132,11 @@
    text2punc = Text2PuncVADRealtime(train_config, model_file, device)
    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[List[Any], bytes, str] = None,
        output_dir_v2: Optional[str] = None,
        cache: List[Any] = None,
        param_dict: dict = None,
            data_path_and_name_and_type,
            raw_inputs: Union[List[Any], bytes, str] = None,
            output_dir_v2: Optional[str] = None,
            cache: List[Any] = None,
            param_dict: dict = None,
    ):
        results = []
        split_size = 10
@@ -177,7 +159,6 @@
    return _forward
def inference_launch(mode, **kwargs):
    if mode == "punc":
        return inference_punc(**kwargs)
@@ -186,6 +167,7 @@
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    parser = config_argparse.ArgumentParser(
@@ -267,7 +249,6 @@
    kwargs.pop("njob", None)
    inference_pipeline = inference_launch(**kwargs)
    return inference_pipeline(kwargs["data_path_and_name_and_type"])
if __name__ == "__main__":