| | |
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
# NOTE(review): duplicate import (diff residue) — `split_to_mini_sentence`
# is imported from two different modules; the second import shadows the
# first, so funasr.datasets.preprocessor is the one actually in effect.
# One of the two lines should be dropped once the residue is resolved.
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
from funasr.datasets.preprocessor import split_to_mini_sentence
| | | |
| | | |
class Text2Punc:
    # NOTE(review): this chunk is diff/merge residue — the class body is
    # incomplete (the __init__ header and several statements are missing
    # from view) and some lines appear twice as old/new pairs; the pairs
    # are flagged below. The residue must be resolved before this runs.

        # Tail of a (missing) preprocessor-construction call — presumably
        # something like CommonPreprocessor(..., text_name=...,
        # non_linguistic_symbols=...); TODO confirm against the full file.
        text_name="text",
        non_linguistic_symbols=train_args.non_linguistic_symbols,
    )
    print("start decoding!!!")


    @torch.no_grad()
    def __call__(self, text: Union[list, str], cache: list, split_size=20):
        # NOTE(review): the `if` branch matching this `else` is missing
        # from this chunk — presumably it handles a non-empty incoming
        # cache; confirm against the full file.
        else:
            precache = ""
            cache = []
        # Diff residue: two versions of the same assignment — the second
        # (joining cache text and new text with a space) overwrites the
        # first and is the one that takes effect.
        data = {"text": precache + text}
        data = {"text": precache + " " + text}
        result = self.preprocessor(data=data, uid="12938712838719")
        split_text = self.preprocessor.pop_split_text_data(result)
        # Chunk the tokenized text into mini-sentences of at most
        # `split_size` tokens for per-chunk punctuation prediction.
        mini_sentences = split_to_mini_sentence(split_text, split_size)

        # NOTE(review): the loop header producing `mini_sentence_id` is
        # missing from this chunk.
        data = {
            "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
            "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
            # Diff residue: duplicate "vad_indexes" key — in a dict
            # literal the later entry (len(cache), without the -1) wins
            # and is the version that takes effect.
            "vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')),
            "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
        }
        data = to_device(data, self.device)
        y, _ = self.wrapped_model(**data)
| | |
    # NOTE(review): fragment — the `def inference(...)` signature above
    # this point is missing from this chunk.
    **kwargs,
):
    # NOTE(review): `assert` is stripped under `python -O`; this runtime
    # type check (typeguard) silently disappears in optimized runs.
    assert check_argument_types()
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    # Limit intra-op CPU threads for torch; defaults to 1 when the caller
    # does not supply "ncpu".
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)

    # Fall back to CPU unless a GPU was requested AND CUDA is available.
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
| | |
    ):
        results = []
        split_size = 10

        # Streaming cache carried across calls via the caller-owned
        # param_dict; assumes param_dict is a dict with a "cache" key —
        # TODO confirm (no None-guard visible in this chunk).
        cache_in = param_dict["cache"]
        # NOTE(review): prefer `raw_inputs is not None` (PEP 8) — `!=`
        # dispatches to __ne__ rather than testing identity.
        if raw_inputs != None:
            line = raw_inputs.strip()
            key = "demo"
            # NOTE(review): the guard condition introducing this
            # early-return branch (most likely `if len(line) == 0:`) is
            # missing from this chunk.
            item = {'key': key, 'value': ""}
            results.append(item)
            return results
            # Diff residue: the next two lines are the OLD version
            # (threading `cache` through and returning it in the item)...
            result, _, cache = text2punc(line, cache)
            item = {'key': key, 'value': result, 'cache': cache}
            # ...and these three are the NEW version, which reads the
            # cache from param_dict and writes the updated cache back
            # instead of returning it to the caller.
            result, _, cache = text2punc(line, cache_in)
            param_dict["cache"] = cache
            item = {'key': key, 'value': result}
            results.append(item)
            return results
| | | |
| | | for inference_text, _, _ in data_path_and_name_and_type: |
| | | with open(inference_text, "r", encoding="utf-8") as fin: |
| | | for line in fin: |
| | | line = line.strip() |
| | | segs = line.split("\t") |
| | | if len(segs) != 2: |
| | | continue |
| | | key = segs[0] |
| | | if len(segs[1]) == 0: |
| | | continue |
| | | result, _ = text2punc(segs[1]) |
| | | item = {'key': key, 'value': result} |
| | | results.append(item) |
| | | output_path = output_dir_v2 if output_dir_v2 is not None else output_dir |
| | | if output_path != None: |
| | | output_file_name = "infer.out" |
| | | Path(output_path).mkdir(parents=True, exist_ok=True) |
| | | output_file_path = (Path(output_path) / output_file_name).absolute() |
| | | with open(output_file_path, "w", encoding="utf-8") as fout: |
| | | for item_i in results: |
| | | key_out = item_i["key"] |
| | | value_out = item_i["value"] |
| | | fout.write(f"{key_out}\t{value_out}\n") |
| | | return results |
| | | |
| | | return _forward |