aky15
2023-05-09 77045e7bb78d4b8a82f96130f9d84e356a32d5c5
funasr/bin/punctuation_infer_vadrealtime.py
@@ -23,7 +23,7 @@
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
@@ -69,6 +69,7 @@
            precache = "".join(cache)
        else:
            precache = ""
            cache = []
        data = {"text": precache + text}
        result = self.preprocessor(data=data, uid="12938712838719")
        split_text = self.preprocessor.pop_split_text_data(result)
@@ -225,7 +226,7 @@
    ):
        results = []
        split_size = 10
        cache_in = param_dict["cache"]
        if raw_inputs != None:
            line = raw_inputs.strip()
            key = "demo"
@@ -233,35 +234,12 @@
                item = {'key': key, 'value': ""}
                results.append(item)
                return results
            #import pdb;pdb.set_trace()
            result, _, cache = text2punc(line, cache)
            item = {'key': key, 'value': result, 'cache': cache}
            result, _, cache = text2punc(line, cache_in)
            param_dict["cache"] = cache
            item = {'key': key, 'value': result}
            results.append(item)
            return results
        for inference_text, _, _ in data_path_and_name_and_type:
            with open(inference_text, "r", encoding="utf-8") as fin:
                for line in fin:
                    line = line.strip()
                    segs = line.split("\t")
                    if len(segs) != 2:
                        continue
                    key = segs[0]
                    if len(segs[1]) == 0:
                        continue
                    result, _ = text2punc(segs[1])
                    item = {'key': key, 'value': result}
                    results.append(item)
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path != None:
            output_file_name = "infer.out"
            Path(output_path).mkdir(parents=True, exist_ok=True)
            output_file_path = (Path(output_path) / output_file_name).absolute()
            with open(output_file_path, "w", encoding="utf-8") as fout:
                for item_i in results:
                    key_out = item_i["key"]
                    value_out = item_i["value"]
                    fout.write(f"{key_out}\t{value_out}\n")
        return results
    return _forward