    vad_infer_config: Optional[str] = None,
    vad_model_file: Optional[str] = None,
    vad_cmvn_file: Optional[str] = None,
    time_stamp_writer: bool = True,
    punc_infer_config: Optional[str] = None,
    punc_model_file: Optional[str] = None,
    outputs_dict: Optional[bool] = True,
    **kwargs,
):
    assert check_argument_types()

    text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)

    if output_dir is not None:
        writer = DatadirWriter(output_dir)
        ibest_writer = writer["1best_recog"]
        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
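    # DatadirWriter is assumed here to be the ESPnet-style helper of the same
    # name: indexing lazily creates Kaldi-format text files, so e.g.
    #   writer["1best_recog"]["text"]["utt1"] = "hello world"
    # appends the line "utt1 hello world" to <output_dir>/1best_recog/text.
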
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None):

        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
            ibest_writer = writer["1best_recog"]
            # ibest_writer["punc_dict"][""] = " ".join(punc_infer_config.punc_list)
            # ibest_writer["token_list"][""] = " ".join(asr_train_config.token_list)
        else:
            writer = None
        # Create a directory: outdir/{n}best_recog
        if writer is not None:
            ibest_writer = writer["1best_recog"]

            # Write the result to each file
            ibest_writer["token"][key] = " ".join(token)
            ibest_writer["token_int"][key] = " ".join(map(str, token_int))
            ibest_writer["vad"][key] = "{}".format(vadsegments)
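
        # Each utterance key gets one line per output file, e.g. (illustrative values):
        #   1best_recog/token     -> "utt1 hel lo wor ld"
        #   1best_recog/token_int -> "utt1 12 305 98 41"
        #   1best_recog/vad       -> "utt1 [[0, 2300], [2500, 4800]]"
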
        if text is not None:
            postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)

            punc_id_list = None

            item = {'key': key, 'value': text_postprocessed_punc_time_stamp, 'text': text_postprocessed,
                    'time_stamp': time_stamp_postprocessed, 'punc': punc_id_list, 'token': token}
            if outputs_dict:
                item = {'text_punc': text_postprocessed_punc, 'text': text_postprocessed,
                        'punc_id': punc_id_list, 'token': token, 'time_stamp': time_stamp_postprocessed}
                item = {'key': key, 'value': item}
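            # Illustrative shape of the appended result when outputs_dict is True
            # (all field values below are hypothetical):
            #   {'key': 'utt1',
            #    'value': {'text_punc': '你好，世界。', 'text': '你好世界',
            #              'punc_id': [1, 1, 1, 3], 'token': ['你', '好', '世', '界'],
            #              'time_stamp': [[0, 400], [400, 800], [800, 1150], [1150, 1500]]}}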
            asr_result_list.append(item)
            finish_count += 1
            # asr_utils.print_progress(finish_count / file_count)
            if writer is not None:
                ibest_writer["text"][key] = text_postprocessed
                if time_stamp_writer and time_stamp_postprocessed is not None:
                    # each entry is a (start, end) pair, serialized as "start-end"
                    ibest_writer["time_stamp"][key] = " ".join(
                        ["-".join(map(str, ts)) for ts in time_stamp_postprocessed])
                ibest_writer["punc_id"][key] = "{}".format(punc_id_list)
                ibest_writer["text_with_punc"][key] = text_postprocessed_punc_time_stamp

            logging.info("decoding, utt: {}, predictions: {}, time_stamp: {}".format(
                key, text_postprocessed_punc, time_stamp_postprocessed))

            punc_list[i] = "？"
        elif punc_list[i] == "。":
            period = i

    preprocessor = CommonPreprocessor(
        train=False,
        token_type="word",

    cache_sent = []
    mini_sentences = split_to_mini_sentence(words, split_size)
    new_mini_sentence = ""
    new_mini_sentence_punc = []
    cache_pop_trigger_limit = 200
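    # Assumption: split_to_mini_sentence(words, n) chunks the word list into
    # consecutive pieces of at most n tokens (one possible implementation is
    # sketched after this function).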
    for mini_sentence_i in range(len(mini_sentences)):
        mini_sentence = mini_sentences[mini_sentence_i]
        mini_sentence = cache_sent + mini_sentence

        # Search for the last Period/QuestionMark as cache
        if mini_sentence_i < len(mini_sentences) - 1:
            sentenceEnd = -1
            last_comma_index = -1
            for i in range(len(punctuations) - 2, 1, -1):
                if punc_list[punctuations[i]] == "。" or punc_list[punctuations[i]] == "？":
                    sentenceEnd = i
                    break
                if last_comma_index < 0 and punc_list[punctuations[i]] == "，":
                    last_comma_index = i

            if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                # The sentence is too long; cut it off at the last comma.
                sentenceEnd = last_comma_index
                punctuations[sentenceEnd] = period
            cache_sent = mini_sentence[sentenceEnd + 1:]
            mini_sentence = mini_sentence[0:sentenceEnd + 1]
            punctuations = punctuations[0:sentenceEnd + 1]
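            # Illustration of the caching (hypothetical tokens): if this chunk's
            # words end "... 好 的 。 那 我们" and the final 。 sits at index
            # sentenceEnd, cache_sent becomes ["那", "我们"]; those words are
            # prepended to the next mini-sentence, so a sentence that straddles
            # a chunk boundary is punctuated as one unit.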

        # continue

        punctuations_np = punctuations.cpu().numpy()
        new_mini_sentence_punc += [int(x) for x in punctuations_np]
        words_with_punc = []
        for i in range(len(mini_sentence)):
            if i > 0:

        new_mini_sentence += "".join(words_with_punc)

    return new_mini_sentence, new_mini_sentence_punc

    return _forward

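
# A minimal sketch of split_to_mini_sentence, consistent with how it is used
# above (an assumption for illustration, not the library's actual code): chunk
# `words` into consecutive pieces of at most `word_limit` tokens.
def split_to_mini_sentence_sketch(words: list, word_limit: int = 20):
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    sentences = []
    n_full = len(words) // word_limit
    for i in range(n_full):
        # one full chunk of exactly word_limit tokens
        sentences.append(words[i * word_limit:(i + 1) * word_limit])
    if len(words) % word_limit > 0:
        # the shorter remainder, if any
        sentences.append(words[n_full * word_limit:])
    return sentences

# e.g. split_to_mini_sentence_sketch(["a", "b", "c", "d", "e"], 2)
#      -> [["a", "b"], ["c", "d"], ["e"]]
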
def get_parser():