Daniel
2023-03-08 7b639890f21312218138f63a1ea74b5e55d38bd8
Merge branch 'alibaba-damo-academy:main' into main
2个文件已修改
2个文件已添加
366 ■■■■■ 已修改文件
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punc_inference_launch.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punctuation_infer_vadrealtime.py 335 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py
New file
@@ -0,0 +1,26 @@
# Example: streaming punctuation restoration over VAD-segmented text.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Input text; each "|" marks a boundary between consecutive VAD segments.
inputs = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"

inference_pipline = pipeline(
    task=Tasks.punctuation,
    model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
    model_revision="v1.0.0",
    output_dir="./tmp/"
)

# The cache returned for one segment is fed back in with the next segment.
cache_out = []
punctuated_parts = ["outputs:"]
for segment in inputs.split("|"):
    rec_result = inference_pipline(text_in=segment, cache=cache_out)
    cache_out = rec_result['cache']
    punctuated_parts.append(rec_result['text'])

rec_result_all = "".join(punctuated_parts)
print(rec_result_all)
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.py
@@ -15,7 +15,7 @@
# A keyword argument may be passed only once (repeating it is a SyntaxError).
# NOTE(review): the listing carried both "v1.1.6" and "v1.1.7"; the later
# value is assumed to be the intended pin -- confirm.
inference_pipline = pipeline(
    task=Tasks.punctuation,
    model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
    model_revision="v1.1.7",
    output_dir="./tmp/"
)
funasr/bin/punc_inference_launch.py
@@ -75,6 +75,9 @@
    if mode == "punc":
        from funasr.bin.punctuation_infer import inference_modelscope
        return inference_modelscope(**kwargs)
    if mode == "punc_VadRealtime":
        from funasr.bin.punctuation_infer_vadrealtime import inference_modelscope
        return inference_modelscope(**kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
funasr/bin/punctuation_infer_vadrealtime.py
New file
@@ -0,0 +1,335 @@
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Any
from typing import List
import numpy as np
import torch
from typeguard import check_argument_types
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
from funasr.tasks.punctuation import PunctuationTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
class Text2Punc:
    """Punctuation predictor for text that arrives segment by segment.

    Each call punctuates one new text segment.  The words that follow the last
    sentence-final mark are returned as a cache; the caller passes that cache
    back in with the next segment so the unfinished sentence is re-scored
    together with the new text.
    """

    # Written as escapes so the literals cannot be altered by Unicode
    # normalisation of the source text.
    _FW_COMMA = "\uff0c"     # full-width comma
    _FW_QUESTION = "\uff1f"  # full-width question mark
    _PERIOD = "\u3002"       # ideographic full stop

    def __init__(
        self,
        train_config: Optional[str],
        model_file: Optional[str],
        device: str = "cpu",
        dtype: str = "float32",
    ):
        """Load the model and build the text preprocessor.

        Args:
            train_config: Training config, forwarded to
                ``PunctuationTask.build_model_from_file``.
            model_file: Model checkpoint, forwarded to the same call.
            device: Device the model and the input tensors are moved to.
            dtype: Name of a ``torch`` dtype attribute (e.g. ``"float32"``)
                the wrapped model is cast to.
        """
        # Build the model.
        model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
        self.device = device
        # Wrap the model so that calling the wrapper runs model.inference().
        self.wrapped_model = ForwardAdaptor(model, "inference")
        self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
        # The list is normalised in place, as before (it is train_args' list).
        self.punc_list = train_args.punc_list
        self.period = self._normalize_punc_list(self.punc_list)
        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
            train=False,
            token_type=train_args.token_type,
            token_list=train_args.token_list,
            bpemodel=train_args.bpemodel,
            text_cleaner=train_args.cleaner,
            g2p_type=train_args.g2p,
            text_name="text",
            non_linguistic_symbols=train_args.non_linguistic_symbols,
        )
        print("start decoding!!!")

    @classmethod
    def _normalize_punc_list(cls, punc_list):
        """Map ASCII ',' and '?' to full-width forms in place; return the index of the period.

        Returns 0 when the list holds no ideographic full stop; when it holds
        several, the last index wins.

        NOTE(review): the previous code replaced "," with "," and "?" with "?"
        (a no-op) while still handling the full-width period, which looks like
        the full-width literals were lost to Unicode normalisation.  The
        full-width targets are restored here -- confirm against upstream.
        """
        period = 0
        for i, mark in enumerate(punc_list):
            if mark == ",":
                punc_list[i] = cls._FW_COMMA
            elif mark == "?":
                punc_list[i] = cls._FW_QUESTION
            elif mark == cls._PERIOD:
                period = i
        return period

    @classmethod
    def _find_sentence_end(cls, marks):
        """Search backwards for the last sentence-final mark.

        Positions ``len(marks) - 2`` down to ``2`` are inspected (the final
        position and the first two are never considered, as in the original
        loops).

        Args:
            marks: Punctuation strings, one per word.

        Returns:
            ``(sentence_end, last_comma)``: index of the last period/question
            mark, or -1; and index of the last comma seen before the search
            stopped, or -1.
        """
        last_comma = -1
        for i in range(len(marks) - 2, 1, -1):
            if marks[i] == cls._PERIOD or marks[i] == cls._FW_QUESTION:
                return i, last_comma
            if last_comma < 0 and marks[i] == cls._FW_COMMA:
                last_comma = i
        return -1, last_comma

    @torch.no_grad()
    def __call__(self, text: Union[list, str], cache: list, split_size=20):
        """Punctuate ``text``, using ``cache`` as left context.

        Args:
            text: New text segment.  It is concatenated to the joined cache
                with ``+``, so in practice it must be a ``str``.
            cache: Strings returned as ``cache_out`` by the previous call;
                ``None`` or an empty list means no left context.
            split_size: Chunk size passed to ``split_to_mini_sentence``.

        Returns:
            ``(sentence_out, sentence_punc_list_out, cache_out)``: the newly
            punctuated text (a trailing mark is stripped), the emitted marks,
            and the words after the last sentence-final mark.
        """
        # A missing cache used to crash on len(cache) further down.
        if cache is None:
            cache = []
        num_cached = len(cache)

        data = {"text": "".join(cache) + text}
        result = self.preprocessor(data=data, uid="12938712838719")
        split_text = self.preprocessor.pop_split_text_data(result)
        mini_sentences = split_to_mini_sentence(split_text, split_size)
        # NOTE(review): presumably the preprocessor has replaced data["text"]
        # with token ids by now -- confirm.
        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
        assert len(mini_sentences) == len(mini_sentences_id)

        # Words/ids of the unfinished sentence carried into the next chunk.
        cache_sent = []
        cache_sent_id = np.array([], dtype="int32")
        sentence_punc_list = []
        sentence_words_list = []
        cache_pop_trigger_limit = 200
        last_chunk = len(mini_sentences) - 1

        for mini_sentence_i, (words, word_ids) in enumerate(zip(mini_sentences, mini_sentences_id)):
            mini_sentence = cache_sent + words
            mini_sentence_id = np.concatenate((cache_sent_id, word_ids), axis=0)
            batch = {
                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
                # Index of the last cached entry (-1 when the cache is empty).
                "vad_indexes": torch.from_numpy(np.array([num_cached - 1], dtype='int32')),
            }
            batch = to_device(batch, self.device)
            y, _ = self.wrapped_model(**batch)
            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
            # One predicted punctuation id per word, as plain Python ints.
            punc_ids = indices.reshape(-1).cpu().tolist()
            assert len(punc_ids) == len(mini_sentence)
            marks = [self.punc_list[p] for p in punc_ids]

            # Except for the last chunk, emit only up to the last sentence end
            # and carry the remainder into the next chunk.
            if mini_sentence_i < last_chunk:
                sentence_end, last_comma = self._find_sentence_end(marks)
                if sentence_end < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma >= 0:
                    # The sentence is too long: cut at a comma and turn it
                    # into a period.
                    sentence_end = last_comma
                    marks[sentence_end] = self.punc_list[self.period]
                cache_sent = mini_sentence[sentence_end + 1:]
                cache_sent_id = mini_sentence_id[sentence_end + 1:]
                mini_sentence = mini_sentence[0:sentence_end + 1]
                marks = marks[0:sentence_end + 1]

            sentence_punc_list += marks
            sentence_words_list += mini_sentence

        assert len(sentence_punc_list) == len(sentence_words_list)

        words_with_punc = []
        sentence_punc_list_out = []
        skip_num = 0
        for i in range(len(sentence_words_list)):
            if i > 0:
                # Put a space between two neighbours whose touching characters
                # are both single-byte in UTF-8.
                if (len(sentence_words_list[i][0].encode()) == 1
                        and len(sentence_words_list[i - 1][-1].encode()) == 1):
                    sentence_words_list[i] = " " + sentence_words_list[i]
            # The first num_cached words were already returned by an earlier
            # call, so they are not emitted again.
            if skip_num < num_cached:
                skip_num += 1
            else:
                words_with_punc.append(sentence_words_list[i])
            # This is already true for the last cached word, so the mark that
            # follows it is emitted here (a trailing mark is stripped below).
            if skip_num >= num_cached:
                sentence_punc_list_out.append(sentence_punc_list[i])
                if sentence_punc_list[i] != "_":
                    words_with_punc.append(sentence_punc_list[i])
        sentence_out = "".join(words_with_punc)

        # Everything after the last sentence-final mark becomes the new cache.
        sentence_end, _ = self._find_sentence_end(sentence_punc_list)
        cache_out = sentence_words_list[sentence_end + 1:]

        # Drop a trailing mark; guarded so empty output no longer raises.
        if sentence_out and sentence_out[-1] in self.punc_list:
            sentence_out = sentence_out[:-1]
            sentence_punc_list_out[-1] = "_"
        return sentence_out, sentence_punc_list_out, cache_out
def inference(
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    output_dir: str,
    log_level: Union[int, str],
    train_config: Optional[str],
    model_file: Optional[str],
    key_file: Optional[str] = None,
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
    raw_inputs: Union[List[Any], bytes, str] = None,
    cache: List[Any] = None,
    param_dict: dict = None,
    **kwargs,
):
    """Build the punctuation pipeline and run it once.

    The configuration arguments are forwarded unchanged to
    ``inference_modelscope``; the callable it returns is then invoked with the
    input arguments.

    Args:
        data_path_and_name_and_type: Triples whose first element is the path
            of a text file to punctuate (used when ``raw_inputs`` is None).
        raw_inputs: Text to punctuate directly.
        cache: Left-context cache returned by a previous call, if any.
        Remaining arguments: see ``inference_modelscope``.

    Returns:
        The list of result dicts produced by the pipeline callable.
    """
    inference_pipeline = inference_modelscope(
        output_dir=output_dir,
        batch_size=batch_size,
        dtype=dtype,
        ngpu=ngpu,
        seed=seed,
        num_workers=num_workers,
        log_level=log_level,
        key_file=key_file,
        train_config=train_config,
        model_file=model_file,
        param_dict=param_dict,
        **kwargs,
    )
    # Pass by keyword: the callable's third positional parameter is
    # ``output_dir_v2``, so a positional ``cache`` would land there.
    return inference_pipeline(
        data_path_and_name_and_type,
        raw_inputs=raw_inputs,
        cache=cache,
    )
def inference_modelscope(
    batch_size: int,
    dtype: str,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    key_file: Optional[str],
    train_config: Optional[str],
    model_file: Optional[str],
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
):
    """Load the punctuation model once and return a callable that runs it.

    Args:
        ngpu: CUDA is used when this is >= 1 and CUDA is available, else CPU.
        seed: Seed passed to ``set_all_random_seed``.
        log_level: Level passed to ``logging.basicConfig``.
        train_config: Training config used to build ``Text2Punc``.
        model_file: Model checkpoint used to build ``Text2Punc``.
        output_dir: Default directory for ``infer.out`` in file mode.
        batch_size, dtype, num_workers, key_file, param_dict, **kwargs:
            Accepted but not used by this function.
            NOTE(review): ``dtype`` is not forwarded to ``Text2Punc``, which
            therefore always uses its own default -- confirm this is intended.

    Returns:
        ``_forward(data_path_and_name_and_type, raw_inputs=None,
        output_dir_v2=None, cache=None, param_dict=None)``, which returns a
        list of result dicts.
    """
    assert check_argument_types()
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)
    text2punc = Text2Punc(train_config, model_file, device)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[List[Any], bytes, str] = None,
        output_dir_v2: Optional[str] = None,
        cache: List[Any] = None,
        param_dict: dict = None,
    ):
        """Punctuate ``raw_inputs`` if given, otherwise every listed text file.

        Raw-input mode returns a single item
        ``{'key': 'demo', 'value': text, 'cache': cache}`` and writes no file.
        File mode reads ``key<TAB>text`` lines, returns
        ``{'key', 'value'}`` items and, when an output directory is known,
        writes them to ``infer.out`` there.
        """
        results = []
        # Text2Punc takes len() of the cache, so never hand it None.
        if cache is None:
            cache = []

        if raw_inputs is not None:
            line = raw_inputs.strip()
            key = "demo"
            if line == "":
                # Keep the cache in the result so an empty segment does not
                # make the caller lose its streaming state.
                results.append({'key': key, 'value': "", 'cache': cache})
                return results
            result, _, cache = text2punc(line, cache)
            results.append({'key': key, 'value': result, 'cache': cache})
            return results

        for inference_text, _, _ in data_path_and_name_and_type or ():
            with open(inference_text, "r", encoding="utf-8") as fin:
                for line in fin:
                    segs = line.strip().split("\t")
                    if len(segs) != 2:
                        continue
                    key = segs[0]
                    if len(segs[1]) == 0:
                        continue
                    # Text2Punc needs a cache argument and returns three
                    # values.  Each keyed line is punctuated on its own with
                    # an empty cache.
                    # NOTE(review): confirm lines should not share context.
                    result, _, _ = text2punc(segs[1], [])
                    results.append({'key': key, 'value': result})

        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            output_file_name = "infer.out"
            Path(output_path).mkdir(parents=True, exist_ok=True)
            output_file_path = (Path(output_path) / output_file_name).absolute()
            with open(output_file_path, "w", encoding="utf-8") as fout:
                for item_i in results:
                    key_out = item_i["key"]
                    value_out = item_i["value"]
                    fout.write(f"{key_out}\t{value_out}\n")
        return results

    return _forward
def get_parser():
    """Create the command-line parser for punctuation inference.

    Returns:
        A ``config_argparse.ArgumentParser`` holding the general options, an
        "Input data related" group and a "The model configuration related"
        group.
    """
    parser = config_argparse.ArgumentParser(
        description="Punctuation inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (flag, add_argument keyword arguments), registered in this order.
    general_options = (
        ("--log_level", {
            "type": lambda x: x.upper(),
            "default": "INFO",
            "choices": ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
            "help": "The verbose level of logging",
        }),
        ("--output_dir", {"type": str, "required": False}),
        ("--ngpu", {
            "type": int,
            "default": 0,
            "help": "The number of gpus. 0 indicates CPU mode",
        }),
        ("--seed", {"type": int, "default": 0, "help": "Random seed"}),
        ("--dtype", {
            "default": "float32",
            "choices": ["float16", "float32", "float64"],
            "help": "Data type",
        }),
        ("--num_workers", {
            "type": int,
            "default": 1,
            "help": "The number of workers used for DataLoader",
        }),
        ("--batch_size", {
            "type": int,
            "default": 1,
            "help": "The batch size for inference",
        }),
    )
    for flag, options in general_options:
        parser.add_argument(flag, **options)

    # NOTE(review): argparse applies ``type`` to the raw string, so
    # ``type=list`` yields a list of characters and ``type=dict`` is given a
    # plain string -- confirm both are intended.
    input_options = (
        ("--data_path_and_name_and_type", {
            "type": str2triple_str,
            "action": "append",
            "required": False,
        }),
        ("--raw_inputs", {"type": str, "required": False}),
        ("--cache", {"type": list, "required": False}),
        ("--param_dict", {"type": dict, "required": False}),
        ("--key_file", {"type": str_or_none}),
    )
    input_group = parser.add_argument_group("Input data related")
    for flag, options in input_options:
        input_group.add_argument(flag, **options)

    model_group = parser.add_argument_group("The model configuration related")
    for flag in ("--train_config", "--model_file"):
        model_group.add_argument(flag, type=str)

    return parser
def main(cmd=None):
    """Command-line entry point: parse the arguments and run inference.

    Args:
        cmd: Argument list handed to ``parse_args``; ``None`` makes argparse
            read ``sys.argv``.
    """
    # Echo the invoking command line to stderr.
    print(get_commandline_args(), file=sys.stderr)
    parsed = get_parser().parse_args(cmd)
    # Every parsed option is passed to inference() as a keyword argument.
    inference(**vars(parsed))
if __name__ == "__main__":
    main()