From 1d97d628f2f19674fa50495e984db8185604ca8e Mon Sep 17 00:00:00 2001
From: lzr265946 <lzr265946@alibaba-inc.com>
Date: 星期五, 03 二月 2023 14:11:22 +0800
Subject: [PATCH] Merge branch 'main' into dev
---
funasr/bin/punctuation_infer.py | 328 +++++++++++++++++++++++-------------------------------
1 files changed, 138 insertions(+), 190 deletions(-)
diff --git a/funasr/bin/punctuation_infer.py b/funasr/bin/punctuation_infer.py
index b38ff94..a801ee8 100644
--- a/funasr/bin/punctuation_infer.py
+++ b/funasr/bin/punctuation_infer.py
@@ -3,33 +3,141 @@
import logging
from pathlib import Path
import sys
-import os
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
-from typing import Dict
from typing import Any
from typing import List
import numpy as np
import torch
-from torch.nn.parallel import data_parallel
from typeguard import check_argument_types
-from funasr.datasets.preprocessor import CommonPreprocessor
+from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
from funasr.utils.cli_utils import get_commandline_args
-from funasr.fileio.datadir_writer import DatadirWriter
from funasr.tasks.punctuation import PunctuationTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
-from funasr.utils.types import float_or_none
-from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_words, split_to_mini_sentence
+from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+
+
+class Text2Punc:
+    """Callable that loads a punctuation model once and punctuates text.
+
+    The constructor builds the model and the text preprocessor from a
+    training configuration; calling the instance runs the model over the
+    input in windows and returns the punctuated text.
+    """
+
+    def __init__(
+        self,
+        train_config: Optional[str],
+        model_file: Optional[str],
+        device: str = "cpu",
+        dtype: str = "float32",
+    ):
+        """Build the model, the punctuation table and the preprocessor.
+
+        Args:
+            train_config: Path of the training configuration passed to
+                ``PunctuationTask.build_model_from_file``.
+            model_file: Path of the model parameters passed to the same call.
+            device: Device the model and every input batch are moved to.
+            dtype: Name of a ``torch`` dtype attribute (e.g. ``"float32"``).
+        """
+        # Build Model
+        model, train_args = PunctuationTask.build_model_from_file(train_config, model_file, device)
+        self.device = device
+        # Wrap the model in a ForwardAdaptor configured with the method name
+        # "inference" (presumably calls are routed to model.inference -- confirm).
+        self.wrapped_model = ForwardAdaptor(model, "inference")
+        self.wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
+        # logging.info(f"Model:\n{model}")
+
+        # NOTE: this is the list object owned by train_args; it is edited in place.
+        self.punc_list = train_args.punc_list
+        # Index of the full-width period in punc_list; stays 0 if it is absent.
+        self.period = 0
+        for i, punc in enumerate(self.punc_list):
+            # Replace ASCII comma / question mark with their full-width forms.
+            if punc == ",":
+                self.punc_list[i] = "，"
+            elif punc == "?":
+                self.punc_list[i] = "？"
+            elif punc == "。":
+                self.period = i
+        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
+            train=False,
+            token_type=train_args.token_type,
+            token_list=train_args.token_list,
+            bpemodel=train_args.bpemodel,
+            text_cleaner=train_args.cleaner,
+            g2p_type=train_args.g2p,
+            text_name="text",
+            non_linguistic_symbols=train_args.non_linguistic_symbols,
+        )
+        print("start decoding!!!")
+
+    @torch.no_grad()
+    def __call__(self, text: Union[list, str], split_size=20) -> Tuple[str, List[int]]:
+        """Add punctuation to ``text``.
+
+        The tokens are processed in windows of ``split_size``. For every
+        window except the last, the tokens after the last predicted
+        period/question mark are carried over and re-scored together with
+        the next window.
+
+        Args:
+            text: Raw text handed to the preprocessor.
+            split_size: Window size passed to ``split_to_mini_sentence``.
+
+        Returns:
+            A tuple ``(punctuated_text, punc_ids)`` where ``punc_ids`` holds
+            one index into ``self.punc_list`` per emitted token. Returns
+            ``("", [])`` when the input produces no tokens.
+        """
+        data = {"text": text}
+        result = self.preprocessor(data=data, uid="12938712838719")
+        split_text = self.preprocessor.pop_split_text_data(result)
+        mini_sentences = split_to_mini_sentence(split_text, split_size)
+        # NOTE(review): data["text"] is read back here as the token ids, so this
+        # relies on the preprocessor having replaced it in place -- confirm.
+        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
+        assert len(mini_sentences) == len(mini_sentences_id)
+
+        # Nothing to punctuate: without this guard the return below would hit
+        # unbound locals because the loop body never runs.
+        if len(mini_sentences) == 0:
+            return "", []
+
+        cache_sent = []
+        # Plain ndarray so that np.concatenate below only ever sees ndarrays.
+        cache_sent_id = np.array([], dtype='int32')
+        text_pieces = []
+        new_mini_sentence_punc = []
+        cache_pop_trigger_limit = 200
+        last_index = len(mini_sentences) - 1
+
+        for mini_sentence_i, (window, window_id) in enumerate(zip(mini_sentences, mini_sentences_id)):
+            # Prepend the tokens carried over from the previous window.
+            mini_sentence = cache_sent + window
+            mini_sentence_id = np.concatenate((cache_sent_id, window_id), axis=0)
+            data = {
+                # Extend one dimension to fake a batch dim.
+                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
+                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
+            }
+            data = to_device(data, self.device)
+            y, _ = self.wrapped_model(**data)
+            # Best-scoring punctuation id per row; topk(1) yields shape (N, 1).
+            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
+            # squeeze(1) gives a flat list of ints even when N == 1.
+            punctuations = indices.squeeze(1).cpu().tolist()
+            assert len(punctuations) == len(mini_sentence)
+
+            # Search for the last Period/QuestionMark as cache
+            if mini_sentence_i < last_index:
+                sentence_end = -1
+                last_comma_index = -1
+                # Scan backwards; the final token and the first two are never
+                # considered as a cut point.
+                for i in range(len(punctuations) - 2, 1, -1):
+                    punc = self.punc_list[punctuations[i]]
+                    if punc == "。" or punc == "？":
+                        sentence_end = i
+                        break
+                    if last_comma_index < 0 and punc == "，":
+                        last_comma_index = i
+
+                if sentence_end < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                    # The sentence is too long, cut off at a comma.
+                    sentence_end = last_comma_index
+                    punctuations[sentence_end] = self.period
+                # With sentence_end == -1 the whole window is carried over and
+                # nothing is emitted for this iteration.
+                cache_sent = mini_sentence[sentence_end + 1:]
+                cache_sent_id = mini_sentence_id[sentence_end + 1:]
+                mini_sentence = mini_sentence[0:sentence_end + 1]
+                punctuations = punctuations[0:sentence_end + 1]
+
+            new_mini_sentence_punc += punctuations
+            prev_single_byte = False
+            for word, punc_id in zip(mini_sentence, punctuations):
+                # A first character that encodes to one UTF-8 byte marks an
+                # ASCII token; two such neighbours are separated by a space.
+                single_byte = len(word[0].encode()) == 1
+                if prev_single_byte and single_byte:
+                    word = " " + word
+                prev_single_byte = single_byte
+                text_pieces.append(word)
+                # "_" is the entry for "no punctuation after this token".
+                if self.punc_list[punc_id] != "_":
+                    text_pieces.append(self.punc_list[punc_id])
+
+        new_mini_sentence = "".join(text_pieces)
+        if not new_mini_sentence:
+            return new_mini_sentence, new_mini_sentence_punc
+
+        # Add Period for the end of the sentence
+        # NOTE(review): the second literal below was unreadable in the received
+        # text; "、" is reconstructed because "。" is already handled by the
+        # elif branch -- confirm against the upstream file.
+        new_mini_sentence_out = new_mini_sentence
+        new_mini_sentence_punc_out = new_mini_sentence_punc
+        if new_mini_sentence[-1] == "，" or new_mini_sentence[-1] == "、":
+            new_mini_sentence_out = new_mini_sentence[:-1] + "。"
+            new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+        elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "？":
+            new_mini_sentence_out = new_mini_sentence + "。"
+            new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+        return new_mini_sentence_out, new_mini_sentence_punc_out
def inference(
@@ -45,12 +153,12 @@
key_file: Optional[str] = None,
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
raw_inputs: Union[List[Any], bytes, str] = None,
-
+ cache: List[Any] = None,
+ param_dict: dict = None,
**kwargs,
):
inference_pipeline = inference_modelscope(
output_dir=output_dir,
- raw_inputs=raw_inputs,
batch_size=batch_size,
dtype=dtype,
ngpu=ngpu,
@@ -60,6 +168,7 @@
key_file=key_file,
train_config=train_config,
model_file=model_file,
+ param_dict=param_dict,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -76,6 +185,7 @@
train_config: Optional[str],
model_file: Optional[str],
output_dir: Optional[str] = None,
+ param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
@@ -91,41 +201,14 @@
# 1. Set random-seed
set_all_random_seed(seed)
-
- # 2. Build Model
- model, train_args = PunctuationTask.build_model_from_file(
- train_config, model_file, device)
- # Wrape model to make model.nll() data-parallel
- wrapped_model = ForwardAdaptor(model, "inference")
- wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
- logging.info(f"Model:\n{model}")
- punc_list = train_args.punc_list
- period = 0
- for i in range(len(punc_list)):
- if punc_list[i] == ",":
- punc_list[i] = "锛�"
- elif punc_list[i] == "?":
- punc_list[i] = "锛�"
- elif punc_list[i] == "銆�":
- period = i
-
- preprocessor = CommonPreprocessor(
- train=False,
- token_type="word",
- token_list=train_args.token_list,
- bpemodel=train_args.bpemodel,
- text_cleaner=train_args.cleaner,
- g2p_type=train_args.g2p,
- text_name="text",
- non_linguistic_symbols=train_args.non_linguistic_symbols,
- )
-
- print("start decoding!!!")
+ text2punc = Text2Punc(train_config, model_file, device)
def _forward(
data_path_and_name_and_type,
raw_inputs: Union[List[Any], bytes, str] = None,
output_dir_v2: Optional[str] = None,
+ cache: List[Any] = None,
+ param_dict: dict = None,
):
results = []
split_size = 20
@@ -133,77 +216,14 @@
if raw_inputs != None:
line = raw_inputs.strip()
key = "demo"
- if line=="":
+ if line == "":
item = {'key': key, 'value': ""}
results.append(item)
return results
- cache_sent = []
- words = split_words(line)
- new_mini_sentence = ""
- new_mini_sentence_punc = ""
- cache_pop_trigger_limit = 200
- mini_sentences = split_to_mini_sentence(words, split_size)
- for mini_sentence_i in range(len(mini_sentences)):
- mini_sentence = mini_sentences[mini_sentence_i]
- mini_sentence = cache_sent + mini_sentence
- data = {"text": " ".join(mini_sentence)}
- batch = preprocessor(data=data, uid="12938712838719")
- batch["text_lengths"] = torch.from_numpy(
- np.array([len(batch["text"])], dtype='int32'))
- batch["text"] = torch.from_numpy(batch["text"])
- # Extend one dimension to fake a batch dim.
- batch["text"] = torch.unsqueeze(batch["text"], 0)
- batch = to_device(batch, device)
- y, _ = wrapped_model(**batch)
- _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
- punctuations = indices
- if indices.size()[0] != 1:
- punctuations = torch.squeeze(indices)
- assert punctuations.size()[0] == len(mini_sentence)
-
- # Search for the last Period/QuestionMark as cache
- if mini_sentence_i < len(mini_sentences)-1:
- sentenceEnd = -1
- last_comma_index = -1
- for i in range(len(punctuations)-2,1,-1):
- if punc_list[punctuations[i]] == "銆�" or punc_list[punctuations[i]] == "锛�":
- sentenceEnd = i
- break
- if last_comma_index < 0 and punc_list[punctuations[i]] == "锛�":
- last_comma_index = i
- if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
- # The sentence it too long, cut off at a comma.
- sentenceEnd = last_comma_index
- punctuations[sentenceEnd] = period
- cache_sent = mini_sentence[sentenceEnd+1:]
- mini_sentence = mini_sentence[0:sentenceEnd+1]
- punctuations = punctuations[0:sentenceEnd+1]
-
- punctuations_np = punctuations.cpu().numpy()
- new_mini_sentence_punc += "".join([str(x) for x in punctuations_np])
- words_with_punc = []
- for i in range(len(mini_sentence)):
- if i>0:
- if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i-1][0].encode()) == 1:
- mini_sentence[i] = " "+ mini_sentence[i]
- words_with_punc.append(mini_sentence[i])
- if punc_list[punctuations[i]] != "_":
- words_with_punc.append(punc_list[punctuations[i]])
- new_mini_sentence += "".join(words_with_punc)
-
- # Add Period for the end of the sentence
- new_mini_sentence_out = new_mini_sentence
- new_mini_sentence_punc_out = new_mini_sentence_punc
- if mini_sentence_i == len(mini_sentences)-1:
- if new_mini_sentence[-1]=="锛�" or new_mini_sentence[-1]=="銆�":
- new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
- new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + str(period)
- elif new_mini_sentence[-1]!="銆�" and new_mini_sentence[-1]!="锛�":
- new_mini_sentence_out=new_mini_sentence+"銆�"
- new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + str(period)
- item = {'key': key, 'value': new_mini_sentence_out}
- results.append(item)
-
+ result, _ = text2punc(line)
+ item = {'key': key, 'value': result}
+ results.append(item)
+ print(results)
return results
for inference_text, _, _ in data_path_and_name_and_type:
@@ -216,72 +236,9 @@
key = segs[0]
if len(segs[1]) == 0:
continue
- cache_sent = []
- words = split_words(segs[1])
- new_mini_sentence = ""
- new_mini_sentence_punc = ""
- cache_pop_trigger_limit = 200
- mini_sentences = split_to_mini_sentence(words, split_size)
- for mini_sentence_i in range(len(mini_sentences)):
- mini_sentence = mini_sentences[mini_sentence_i]
- mini_sentence = cache_sent + mini_sentence
- data = {"text": " ".join(mini_sentence)}
- batch = preprocessor(data=data, uid="12938712838719")
- batch["text_lengths"] = torch.from_numpy(
- np.array([len(batch["text"])], dtype='int32'))
- batch["text"] = torch.from_numpy(batch["text"])
- # Extend one dimension to fake a batch dim.
- batch["text"] = torch.unsqueeze(batch["text"], 0)
- batch = to_device(batch, device)
- y, _ = wrapped_model(**batch)
- _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
- punctuations = indices
- if indices.size()[0] != 1:
- punctuations = torch.squeeze(indices)
- assert punctuations.size()[0] == len(mini_sentence)
-
- # Search for the last Period/QuestionMark as cache
- if mini_sentence_i < len(mini_sentences)-1:
- sentenceEnd = -1
- last_comma_index = -1
- for i in range(len(punctuations)-2,1,-1):
- if punc_list[punctuations[i]] == "銆�" or punc_list[punctuations[i]] == "锛�":
- sentenceEnd = i
- break
- if last_comma_index < 0 and punc_list[punctuations[i]] == "锛�":
- last_comma_index = i
- if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
- # The sentence it too long, cut off at a comma.
- sentenceEnd = last_comma_index
- punctuations[sentenceEnd] = period
- cache_sent = mini_sentence[sentenceEnd+1:]
- mini_sentence = mini_sentence[0:sentenceEnd+1]
- punctuations = punctuations[0:sentenceEnd+1]
-
- punctuations_np = punctuations.cpu().numpy()
- new_mini_sentence_punc += "".join([str(x) for x in punctuations_np])
- words_with_punc = []
- for i in range(len(mini_sentence)):
- if i>0:
- if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i-1][0].encode()) == 1:
- mini_sentence[i] = " "+ mini_sentence[i]
- words_with_punc.append(mini_sentence[i])
- if punc_list[punctuations[i]] != "_":
- words_with_punc.append(punc_list[punctuations[i]])
- new_mini_sentence += "".join(words_with_punc)
-
- # Add Period for the end of the sentence
- new_mini_sentence_out = new_mini_sentence
- new_mini_sentence_punc_out = new_mini_sentence_punc
- if mini_sentence_i == len(mini_sentences)-1:
- if new_mini_sentence[-1]=="锛�" or new_mini_sentence[-1]=="銆�":
- new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
- new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + str(period)
- elif new_mini_sentence[-1]!="銆�" and new_mini_sentence[-1]!="锛�":
- new_mini_sentence_out=new_mini_sentence+"銆�"
- new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + str(period)
- item = {'key': key, 'value': new_mini_sentence_out}
- results.append(item)
+ result, _ = text2punc(segs[1])
+ item = {'key': key, 'value': result}
+ results.append(item)
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path != None:
output_file_name = "infer.out"
@@ -293,6 +250,7 @@
value_out = item_i["value"]
fout.write(f"{key_out}\t{value_out}\n")
return results
+
return _forward
@@ -338,19 +296,11 @@
)
group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- action="append",
- required=False
- )
- group.add_argument(
- "--raw_inputs",
- type=str,
- required=False
- )
+ group.add_argument("--data_path_and_name_and_type", type=str2triple_str, action="append", required=False)
+ group.add_argument("--raw_inputs", type=str, required=False)
+ group.add_argument("--cache", type=list, required=False)
+ group.add_argument("--param_dict", type=dict, required=False)
group.add_argument("--key_file", type=str_or_none)
-
group = parser.add_argument_group("The model configuration related")
group.add_argument("--train_config", type=str)
@@ -364,11 +314,9 @@
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
- # kwargs.pop("config", None)
+ # kwargs.pop("config", None)
inference(**kwargs)
+
if __name__ == "__main__":
main()
-
-
-
--
Gitblit v1.9.1