From 2a05f164f5d3857739be6838e6e01342259d6779 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 07 三月 2023 18:42:10 +0800
Subject: [PATCH] Merge pull request #192 from alibaba-damo-academy/dev_cmz

---
 funasr/bin/punc_inference_launch.py                                                          |    3 
 funasr/bin/punctuation_infer_vadrealtime.py                                                  |  335 +++++++++++++++++++++++++++++++++++++++++++++++
 egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.py     |    2 
 egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py |   26 +++
 4 files changed, 365 insertions(+), 1 deletions(-)

diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py
new file mode 100644
index 0000000..02859c2
--- /dev/null
+++ b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py
@@ -0,0 +1,26 @@
+
+##################text浜岃繘鍒舵暟鎹�#####################
+inputs = "璺ㄥ娌虫祦鏄吇鑲叉部宀竱浜烘皯鐨勭敓鍛戒箣婧愰暱鏈熶互鏉ヤ负甯姪涓嬫父鍦板尯闃茬伨鍑忕伨涓柟鎶�鏈汉鍛榺鍦ㄤ笂娓稿湴鍖烘瀬涓烘伓鍔g殑鑷劧鏉′欢涓嬪厠鏈嶅法澶у洶闅剧敋鑷冲啋鐫�鐢熷懡鍗遍櫓|鍚戝嵃鏂规彁渚涙睕鏈熸按鏂囪祫鏂欏鐞嗙揣鎬ヤ簨浠朵腑鏂归噸瑙嗗嵃鏂瑰湪璺ㄥ娌虫祦闂涓婄殑鍏冲垏|鎰挎剰杩涗竴姝ュ畬鍠勫弻鏂硅仈鍚堝伐浣滄満鍒秥鍑℃槸|涓柟鑳藉仛鐨勬垜浠瑋閮戒細鍘诲仛鑰屼笖浼氬仛寰楁洿濂芥垜璇峰嵃搴︽湅鍙嬩滑鏀惧績涓浗鍦ㄤ笂娓哥殑|浠讳綍寮�鍙戝埄鐢ㄩ兘浼氱粡杩囩瀛瑙勫垝鍜岃璇佸吋椤句笂涓嬫父鐨勫埄鐩�"
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipline = pipeline(
+    task=Tasks.punctuation,
+    model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
+    model_revision="v1.0.0",
+    output_dir="./tmp/"
+)
+
+vads = inputs.split("|")
+
+cache_out = []
+rec_result_all="outputs:"
+for vad in vads:
+    rec_result = inference_pipline(text_in=vad, cache=cache_out)
+    #print(rec_result)
+    cache_out = rec_result['cache']
+    rec_result_all += rec_result['text']
+
+print(rec_result_all)
+
diff --git a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.py b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.py
index 8dac292..0da8d25 100644
--- a/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.py
+++ b/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/infer.py
@@ -15,7 +15,7 @@
 inference_pipline = pipeline(
     task=Tasks.punctuation,
     model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
-    model_revision="v1.1.6",
+    model_revision="v1.1.7",
     output_dir="./tmp/"
 )
 
diff --git a/funasr/bin/punc_inference_launch.py b/funasr/bin/punc_inference_launch.py
index 53db1df..e7e3f15 100755
--- a/funasr/bin/punc_inference_launch.py
+++ b/funasr/bin/punc_inference_launch.py
@@ -75,6 +75,9 @@
     if mode == "punc":
         from funasr.bin.punctuation_infer import inference_modelscope
         return inference_modelscope(**kwargs)
+    if mode == "punc_VadRealtime":
+        from funasr.bin.punctuation_infer_vadrealtime import inference_modelscope
+        return inference_modelscope(**kwargs)
     else:
         logging.info("Unknown decoding mode: {}".format(mode))
         return None
diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py
new file mode 100644
index 0000000..d6cc153
--- /dev/null
+++ b/funasr/bin/punctuation_infer_vadrealtime.py
@@ -0,0 +1,335 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+from pathlib import Path
+import sys
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Any
+from typing import List
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+
+from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.tasks.punctuation import PunctuationTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.forward_adaptor import ForwardAdaptor
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+
+
+class Text2Punc:
+
+    def __init__(
+        self,
+        train_config: Optional[str],
+        model_file: Optional[str],
+        device: str = "cpu",
+        dtype: str = "float32",
+    ):
+        #  Build Model
+        model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
+        self.device = device
+        # Wrape model to make model.nll() data-parallel
+        self.wrapped_model = ForwardAdaptor(model, "inference")
+        self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
+        # logging.info(f"Model:\n{model}")
+        self.punc_list = train_args.punc_list
+        self.period = 0
+        for i in range(len(self.punc_list)):
+            if self.punc_list[i] == ",":
+                self.punc_list[i] = "锛�"
+            elif self.punc_list[i] == "?":
+                self.punc_list[i] = "锛�"
+            elif self.punc_list[i] == "銆�":
+                self.period = i
+        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
+            train=False,
+            token_type=train_args.token_type,
+            token_list=train_args.token_list,
+            bpemodel=train_args.bpemodel,
+            text_cleaner=train_args.cleaner,
+            g2p_type=train_args.g2p,
+            text_name="text",
+            non_linguistic_symbols=train_args.non_linguistic_symbols,
+        )
+        print("start decoding!!!")
+
+    @torch.no_grad()
+    def __call__(self, text: Union[list, str], cache: list, split_size=20):
+        if cache is not None and len(cache) > 0:
+            precache = "".join(cache)
+        else:
+            precache = ""
+        data = {"text": precache + text}
+        result = self.preprocessor(data=data, uid="12938712838719")
+        split_text = self.preprocessor.pop_split_text_data(result)
+        mini_sentences = split_to_mini_sentence(split_text, split_size)
+        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
+        assert len(mini_sentences) == len(mini_sentences_id)
+        cache_sent = []
+        cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
+        sentence_punc_list = []
+        sentence_words_list= []
+        cache_pop_trigger_limit = 200
+        skip_num = 0
+        for mini_sentence_i in range(len(mini_sentences)):
+            mini_sentence = mini_sentences[mini_sentence_i]
+            mini_sentence_id = mini_sentences_id[mini_sentence_i]
+            mini_sentence = cache_sent + mini_sentence
+            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+            data = {
+                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
+                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
+                "vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')),
+            }
+            data = to_device(data, self.device)
+            y, _ = self.wrapped_model(**data)
+            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
+            punctuations = indices
+            if indices.size()[0] != 1:
+                punctuations = torch.squeeze(indices)
+            assert punctuations.size()[0] == len(mini_sentence)
+
+            # Search for the last Period/QuestionMark as cache
+            if mini_sentence_i < len(mini_sentences) - 1:
+                sentenceEnd = -1
+                last_comma_index = -1
+                for i in range(len(punctuations) - 2, 1, -1):
+                    if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
+                        sentenceEnd = i
+                        break
+                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
+                        last_comma_index = i
+    
+                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                    # The sentence it too long, cut off at a comma.
+                    sentenceEnd = last_comma_index
+                    punctuations[sentenceEnd] = self.period
+                cache_sent = mini_sentence[sentenceEnd + 1:]
+                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
+                mini_sentence = mini_sentence[0:sentenceEnd + 1]
+                punctuations = punctuations[0:sentenceEnd + 1]
+
+            punctuations_np = punctuations.cpu().numpy()
+            sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
+            sentence_words_list += mini_sentence
+
+        assert len(sentence_punc_list) == len(sentence_words_list)
+        words_with_punc = []
+        sentence_punc_list_out = []
+        for i in range(0, len(sentence_words_list)):
+            if i > 0:
+                if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
+                    sentence_words_list[i] = " " + sentence_words_list[i]
+            if skip_num < len(cache):
+                skip_num += 1
+            else:
+                words_with_punc.append(sentence_words_list[i])
+            if skip_num >= len(cache):
+                sentence_punc_list_out.append(sentence_punc_list[i])
+                if sentence_punc_list[i] != "_":
+                    words_with_punc.append(sentence_punc_list[i])
+        sentence_out = "".join(words_with_punc)
+
+        sentenceEnd = -1
+        for i in range(len(sentence_punc_list) - 2, 1, -1):
+            if sentence_punc_list[i] == "銆�" or sentence_punc_list[i] == "锛�":
+               sentenceEnd = i
+               break
+        cache_out = sentence_words_list[sentenceEnd + 1 :]
+        if sentence_out[-1] in self.punc_list:
+            sentence_out = sentence_out[:-1]
+            sentence_punc_list_out[-1] = "_"
+        return sentence_out, sentence_punc_list_out, cache_out
+
+
+def inference(
+    batch_size: int,
+    dtype: str,
+    ngpu: int,
+    seed: int,
+    num_workers: int,
+    output_dir: str,
+    log_level: Union[int, str],
+    train_config: Optional[str],
+    model_file: Optional[str],
+    key_file: Optional[str] = None,
+    data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
+    raw_inputs: Union[List[Any], bytes, str] = None,
+    cache: List[Any] = None,
+    param_dict: dict = None,
+    **kwargs,
+):
+    inference_pipeline = inference_modelscope(
+        output_dir=output_dir,
+        batch_size=batch_size,
+        dtype=dtype,
+        ngpu=ngpu,
+        seed=seed,
+        num_workers=num_workers,
+        log_level=log_level,
+        key_file=key_file,
+        train_config=train_config,
+        model_file=model_file,
+        param_dict=param_dict,
+        **kwargs,
+    )
+    return inference_pipeline(data_path_and_name_and_type, raw_inputs, cache)
+
+
+def inference_modelscope(
+    batch_size: int,
+    dtype: str,
+    ngpu: int,
+    seed: int,
+    num_workers: int,
+    log_level: Union[int, str],
+    #cache: list,
+    key_file: Optional[str],
+    train_config: Optional[str],
+    model_file: Optional[str],
+    output_dir: Optional[str] = None,
+    param_dict: dict = None,
+    **kwargs,
+):
+    assert check_argument_types()
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+    text2punc = Text2Punc(train_config, model_file, device)
+
+    def _forward(
+        data_path_and_name_and_type,
+        raw_inputs: Union[List[Any], bytes, str] = None,
+        output_dir_v2: Optional[str] = None,
+        cache: List[Any] = None,
+        param_dict: dict = None,
+    ):
+        results = []
+        split_size = 10
+
+        if raw_inputs != None:
+            line = raw_inputs.strip()
+            key = "demo"
+            if line == "":
+                item = {'key': key, 'value': ""}
+                results.append(item)
+                return results
+            #import pdb;pdb.set_trace()
+            result, _, cache = text2punc(line, cache)
+            item = {'key': key, 'value': result, 'cache': cache}
+            results.append(item)
+            return results
+
+        for inference_text, _, _ in data_path_and_name_and_type:
+            with open(inference_text, "r", encoding="utf-8") as fin:
+                for line in fin:
+                    line = line.strip()
+                    segs = line.split("\t")
+                    if len(segs) != 2:
+                        continue
+                    key = segs[0]
+                    if len(segs[1]) == 0:
+                        continue
+                    result, _ = text2punc(segs[1])
+                    item = {'key': key, 'value': result}
+                    results.append(item)
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        if output_path != None:
+            output_file_name = "infer.out"
+            Path(output_path).mkdir(parents=True, exist_ok=True)
+            output_file_path = (Path(output_path) / output_file_name).absolute()
+            with open(output_file_path, "w", encoding="utf-8") as fout:
+                for item_i in results:
+                    key_out = item_i["key"]
+                    value_out = item_i["value"]
+                    fout.write(f"{key_out}\t{value_out}\n")
+        return results
+
+    return _forward
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="Punctuation inference",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument("--output_dir", type=str, required=False)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
+    group.add_argument("--raw_inputs", type=str, required=False)
+    group.add_argument("--cache", type=list, required=False)
+    group.add_argument("--param_dict", type=dict, required=False)
+    group.add_argument("--key_file", type=str_or_none)
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument("--train_config", type=str)
+    group.add_argument("--model_file", type=str)
+
+    return parser
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    # kwargs.pop("config", None)
+    inference(**kwargs)
+
+
+if __name__ == "__main__":
+    main()

--
Gitblit v1.9.1