jmwang66
2023-06-20 2ff405b2f4ab899eff9bece232969fbb0c8f0555
funasr/bin/punc_infer.py
@@ -1,33 +1,19 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Any
from typing import List
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.punctuation import PunctuationTask
from funasr.datasets.preprocessor import split_to_mini_sentence
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
@@ -40,7 +26,7 @@
        dtype: str = "float32",
    ):
        #  Build Model
        model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
        model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
        self.device = device
        # Wrape model to make model.nll() data-parallel
        self.wrapped_model = ForwardAdaptor(model, "inference")
@@ -153,7 +139,7 @@
        dtype: str = "float32",
    ):
        #  Build Model
        model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
        model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
        self.device = device
        # Wrape model to make model.nll() data-parallel
        self.wrapped_model = ForwardAdaptor(model, "inference")
@@ -267,5 +253,3 @@
            sentence_out = sentence_out[:-1]
            sentence_punc_list_out[-1] = "_"
        return sentence_out, sentence_punc_list_out, cache_out