游雁
2023-01-17 5ddad6db687ed45bc0b38cfc802ddfc3ab8c7f68
modelscope paraformer large long input
1个文件已修改
43 ■■■■ 已修改文件
funasr/bin/asr_inference_paraformer_vad_punc.py 43 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -666,9 +666,10 @@
    vad_infer_config: Optional[str] = None,
    vad_model_file: Optional[str] = None,
    vad_cmvn_file: Optional[str] = None,
    time_stamp_writer: bool = False,
    time_stamp_writer: bool = True,
    punc_infer_config: Optional[str] = None,
    punc_model_file: Optional[str] = None,
    outputs_dict: Optional[bool] = True,
    **kwargs,
):
    assert check_argument_types()
@@ -726,6 +727,11 @@
    
    text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
    
    if output_dir is not None:
        writer = DatadirWriter(output_dir)
        ibest_writer = writer[f"1best_recog"]
        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
@@ -756,6 +762,9 @@
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
            ibest_writer = writer[f"1best_recog"]
            # ibest_writer["punc_dict"][""] = " ".join(punc_infer_config.punc_list)
            # ibest_writer["token_list"][""] = " ".join(asr_train_config.token_list)
        else:
            writer = None
        
@@ -805,11 +814,10 @@
                
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"1best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["vad"][key] = "{}".format(vadsegments)
                
                if text is not None:
                    postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
@@ -828,15 +836,20 @@
                        punc_id_list = None
                    
                    item = {'key': key, 'value': text_postprocessed_punc_time_stamp, 'text': text_postprocessed,
                            'time_stamp': time_stamp_postprocessed, 'punc': punc_id_list}
                            'time_stamp': time_stamp_postprocessed, 'punc': punc_id_list, 'token': token}
                    if outputs_dict:
                        item = {'text_punc': text_postprocessed_punc, 'text': text_postprocessed,
                                'punc_id': punc_id_list, 'token': token, 'time_stamp': time_stamp_postprocessed}
                        item = {'key': key, 'value': item}
                    asr_result_list.append(item)
                    finish_count += 1
                    # asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text_postprocessed
                        if time_stamp_writer and time_stamp_postprocessed is not None:
                            ibest_writer["time_stamp"][key] = " ".join(
                                ["-".join(map(str, ts)) for ts in time_stamp_postprocessed])
                        ibest_writer["punc_id"][key] = "{}".format(punc_id_list)
                        ibest_writer["text_with_punc"][key] = text_postprocessed_punc_time_stamp
                        if time_stamp_postprocessed is not None:
                            ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
                
                logging.info("decoding, utt: {}, predictions: {}, time_stamp: {}".format(key, text_postprocessed_punc,
                                                                                         time_stamp_postprocessed))
@@ -869,7 +882,6 @@
            punc_list[i] = "?"
        elif punc_list[i] == "。":
            period = i
    preprocessor = CommonPreprocessor(
        train=False,
        token_type="word",
@@ -887,7 +899,8 @@
        cache_sent = []
        mini_sentences = split_to_mini_sentence(words, split_size)
        new_mini_sentence = ""
        new_mini_sentence_punc = ""
        new_mini_sentence_punc = []
        cache_pop_trigger_limit = 200
        for mini_sentence_i in range(len(mini_sentences)):
            mini_sentence = mini_sentences[mini_sentence_i]
            mini_sentence = cache_sent + mini_sentence
@@ -908,11 +921,18 @@
            # Search for the last Period/QuestionMark as cache
            if mini_sentence_i < len(mini_sentences) - 1:
                sentenceEnd = -1
                last_comma_index = -1
                for i in range(len(punctuations) - 2, 1, -1):
                    if punc_list[punctuations[i]] == "。" or punc_list[punctuations[i]] == "?":
                        sentenceEnd = i
                        break
                    if last_comma_index < 0 and punc_list[punctuations[i]] == ",":
                        last_comma_index = i
                
                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                    # The sentence it too long, cut off at a comma.
                    sentenceEnd = last_comma_index
                    punctuations[sentenceEnd] = period
                cache_sent = mini_sentence[sentenceEnd + 1:]
                mini_sentence = mini_sentence[0:sentenceEnd + 1]
                punctuations = punctuations[0:sentenceEnd + 1]
@@ -921,7 +941,7 @@
            #    continue
            
            punctuations_np = punctuations.cpu().numpy()
            new_mini_sentence_punc += "".join([str(x) for x in punctuations_np])
            new_mini_sentence_punc += [int(x) for x in punctuations_np]
            words_with_punc = []
            for i in range(len(mini_sentence)):
                if i > 0:
@@ -933,7 +953,6 @@
            new_mini_sentence += "".join(words_with_punc)
            
        return new_mini_sentence, new_mini_sentence_punc
    return _forward
def get_parser():