嘉渊
2023-06-15 ca6b2e29fd3b0e9bfc4325d266a6416ff5a0d252
update repo
2个文件已修改
165 ■■■■■ 已修改文件
funasr/bin/punc_infer.py 60 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_inference_launch.py 105 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_infer.py
@@ -1,46 +1,32 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Any
from typing import List
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.punctuation import PunctuationTask
from funasr.datasets.preprocessor import split_to_mini_sentence
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
    def __init__(
        self,
        train_config: Optional[str],
        model_file: Optional[str],
        device: str = "cpu",
        dtype: str = "float32",
            self,
            train_config: Optional[str],
            model_file: Optional[str],
            device: str = "cpu",
            dtype: str = "float32",
    ):
        #  Build Model
        model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
        model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
        self.device = device
        # Wrap model to make model.nll() data-parallel
        self.wrapped_model = ForwardAdaptor(model, "inference")
@@ -144,16 +130,16 @@
class Text2PuncVADRealtime:
    def __init__(
        self,
        train_config: Optional[str],
        model_file: Optional[str],
        device: str = "cpu",
        dtype: str = "float32",
            self,
            train_config: Optional[str],
            model_file: Optional[str],
            device: str = "cpu",
            dtype: str = "float32",
    ):
        #  Build Model
        model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
        model, train_args = build_model_from_file(train_config, model_file, None, device, task_name="punc")
        self.device = device
        # Wrap model to make model.nll() data-parallel
        self.wrapped_model = ForwardAdaptor(model, "inference")
@@ -178,7 +164,7 @@
            text_name="text",
            non_linguistic_symbols=train_args.non_linguistic_symbols,
        )
    @torch.no_grad()
    def __call__(self, text: Union[list, str], cache: list, split_size=20):
        if cache is not None and len(cache) > 0:
@@ -215,7 +201,7 @@
            if indices.size()[0] != 1:
                punctuations = torch.squeeze(indices)
            assert punctuations.size()[0] == len(mini_sentence)
            # Search for the last Period/QuestionMark as cache
            if mini_sentence_i < len(mini_sentences) - 1:
                sentenceEnd = -1
@@ -226,7 +212,7 @@
                        break
                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
                        last_comma_index = i
                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                    # The sentence is too long, cut off at a comma.
                    sentenceEnd = last_comma_index
@@ -235,11 +221,11 @@
                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
                mini_sentence = mini_sentence[0:sentenceEnd + 1]
                punctuations = punctuations[0:sentenceEnd + 1]
            punctuations_np = punctuations.cpu().numpy()
            sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
            sentence_words_list += mini_sentence
        assert len(sentence_punc_list) == len(sentence_words_list)
        words_with_punc = []
        sentence_punc_list_out = []
@@ -256,7 +242,7 @@
                if sentence_punc_list[i] != "_":
                    words_with_punc.append(sentence_punc_list[i])
        sentence_out = "".join(words_with_punc)
        sentenceEnd = -1
        for i in range(len(sentence_punc_list) - 2, 1, -1):
            if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "?":
@@ -267,5 +253,3 @@
            sentence_out = sentence_out[:-1]
            sentence_punc_list_out[-1] = "_"
        return sentence_out, sentence_punc_list_out, cache_out
funasr/bin/punc_inference_launch.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
@@ -7,55 +7,36 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.types import float_or_none
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Any
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.punctuation import PunctuationTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.datasets.preprocessor import split_to_mini_sentence
from funasr.bin.punc_infer import Text2Punc, Text2PuncVADRealtime
def inference_punc(
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
        batch_size: int,
        dtype: str,
        ngpu: int,
        seed: int,
        num_workers: int,
        log_level: Union[int, str],
        key_file: Optional[str],
        train_config: Optional[str],
        model_file: Optional[str],
        output_dir: Optional[str] = None,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    logging.basicConfig(
@@ -73,11 +54,11 @@
    text2punc = Text2Punc(train_config, model_file, device)
    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[List[Any], bytes, str] = None,
        output_dir_v2: Optional[str] = None,
        cache: List[Any] = None,
        param_dict: dict = None,
            data_path_and_name_and_type,
            raw_inputs: Union[List[Any], bytes, str] = None,
            output_dir_v2: Optional[str] = None,
            cache: List[Any] = None,
            param_dict: dict = None,
    ):
        results = []
        split_size = 20
@@ -121,20 +102,21 @@
    return _forward
def inference_punc_vad_realtime(
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    #cache: list,
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
        batch_size: int,
        dtype: str,
        ngpu: int,
        seed: int,
        num_workers: int,
        log_level: Union[int, str],
        # cache: list,
        key_file: Optional[str],
        train_config: Optional[str],
        model_file: Optional[str],
        output_dir: Optional[str] = None,
        param_dict: dict = None,
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
@@ -150,11 +132,11 @@
    text2punc = Text2PuncVADRealtime(train_config, model_file, device)
    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[List[Any], bytes, str] = None,
        output_dir_v2: Optional[str] = None,
        cache: List[Any] = None,
        param_dict: dict = None,
            data_path_and_name_and_type,
            raw_inputs: Union[List[Any], bytes, str] = None,
            output_dir_v2: Optional[str] = None,
            cache: List[Any] = None,
            param_dict: dict = None,
    ):
        results = []
        split_size = 10
@@ -177,7 +159,6 @@
    return _forward
def inference_launch(mode, **kwargs):
    if mode == "punc":
        return inference_punc(**kwargs)
@@ -186,6 +167,7 @@
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    parser = config_argparse.ArgumentParser(
@@ -267,7 +249,6 @@
    kwargs.pop("njob", None)
    inference_pipeline = inference_launch(**kwargs)
    return inference_pipeline(kwargs["data_path_and_name_and_type"])
if __name__ == "__main__":