| funasr/bin/punc_infer.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/bin/punc_inference_launch.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/bin/punc_infer.py
@@ -1,33 +1,19 @@ # -*- encoding: utf-8 -*- #!/usr/bin/env python3 # -*- encoding: utf-8 -*- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. # MIT License (https://opensource.org/licenses/MIT) import argparse import logging from pathlib import Path import sys from typing import Optional from typing import Sequence from typing import Tuple from typing import Union from typing import Any from typing import List import numpy as np import torch from typeguard import check_argument_types from funasr.build_utils.build_model_from_file import build_model_from_file from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor from funasr.utils.cli_utils import get_commandline_args from funasr.tasks.punctuation import PunctuationTask from funasr.datasets.preprocessor import split_to_mini_sentence from funasr.torch_utils.device_funcs import to_device from funasr.torch_utils.forward_adaptor import ForwardAdaptor from funasr.torch_utils.set_all_random_seed import set_all_random_seed from funasr.utils import config_argparse from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.datasets.preprocessor import split_to_mini_sentence class Text2Punc: @@ -40,7 +26,7 @@ dtype: str = "float32", ): # Build Model model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device) model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc") self.device = device # Wrape model to make model.nll() data-parallel self.wrapped_model = ForwardAdaptor(model, "inference") @@ -153,7 +139,7 @@ dtype: str = "float32", ): # Build Model model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device) model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc") self.device = device # Wrape model to make model.nll() data-parallel self.wrapped_model = ForwardAdaptor(model, "inference") @@ -267,5 +253,3 @@ sentence_out = sentence_out[:-1] sentence_punc_list_out[-1] = "_" return sentence_out, sentence_punc_list_out, cache_out funasr/bin/punc_inference_launch.py
@@ -1,5 +1,5 @@ # -*- encoding: utf-8 -*- #!/usr/bin/env python3 # -*- encoding: utf-8 -*- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. # MIT License (https://opensource.org/licenses/MIT) @@ -7,41 +7,22 @@ import logging import os import sys from typing import Union, Dict, Any from funasr.utils import config_argparse from funasr.utils.cli_utils import get_commandline_args from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.utils.types import float_or_none import argparse import logging from pathlib import Path import sys from typing import Optional from typing import Sequence from typing import Tuple from typing import Union from typing import Any from typing import List from typing import Optional from typing import Union import numpy as np import torch from typeguard import check_argument_types from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor from funasr.utils.cli_utils import get_commandline_args from funasr.tasks.punctuation import PunctuationTask from funasr.torch_utils.device_funcs import to_device from funasr.torch_utils.forward_adaptor import ForwardAdaptor from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime from funasr.torch_utils.set_all_random_seed import set_all_random_seed from funasr.utils import config_argparse from funasr.utils.cli_utils import get_commandline_args from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.datasets.preprocessor import split_to_mini_sentence from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime def inference_punc( batch_size: int, @@ -121,6 +102,7 @@ return _forward def inference_punc_vad_realtime( batch_size: int, dtype: str, @@ -177,7 +159,6 @@ return _forward def inference_launch(mode, **kwargs): if mode == "punc": return inference_punc(**kwargs) @@ -186,6 +167,7 @@ else: logging.info("Unknown decoding mode: {}".format(mode)) return None def get_parser(): parser = config_argparse.ArgumentParser( @@ -267,7 +249,6 @@ kwargs.pop("njob", None) inference_pipeline = inference_launch(**kwargs) return inference_pipeline(kwargs["data_path_and_name_and_type"]) if __name__ == "__main__":