游雁
2023-01-17 a063430904efc1518616fc77afdd3d30cc607b09
fixbug
1个文件已修改
19 ■■■■■ 已修改文件
funasr/bin/asr_inference_paraformer_vad_punc.py 19 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -562,6 +562,7 @@
        length_total = 0.0
        finish_count = 0
        file_count = 1
        lfr_factor = 6
        # 7 .Start for-loop
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
@@ -597,7 +598,7 @@
                    results = speech2text(**batch)
                    if len(results) < 1:
                        hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                        results = [[" ", ["<space>"], [2], 10, 6]] * nbest
                        results = [[" ", ["<space>"], [2], 0, 1, 6]] * nbest
                    time_end = time.time()
                    forward_time = time_end - time_beg
                    lfr_factor = results[0][-1]
@@ -615,7 +616,8 @@
                
                key = keys[0]
                result = result_segments[0]
                text, token, token_int, time_stamp = result
                text, token, token_int = result[0], result[1], result[2]
                time_stamp = None if len(result) < 4 else result[3]
                
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
@@ -634,11 +636,12 @@
                        text_postprocessed_punc_time_stamp = "predictions: {}  time_stamp: {}".format(
                            text_postprocessed_punc, time_stamp_postprocessed)
                    else:
                        text_postprocessed = postprocessed_result
                        time_stamp_postprocessed = None
                        word_lists = None
                        text_postprocessed_punc_time_stamp = None
                        punc_id_list = None
                        text_postprocessed = ""
                        time_stamp_postprocessed = ""
                        word_lists = ""
                        text_postprocessed_punc_time_stamp = ""
                        punc_id_list = ""
                        text_postprocessed_punc = ""
                    item = {'key': key, 'value': text_postprocessed_punc_time_stamp, 'text': text_postprocessed,
                            'time_stamp': time_stamp_postprocessed, 'punc': punc_id_list, 'token': token}
@@ -660,7 +663,7 @@
                                                                                         time_stamp_postprocessed))
        
        logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
                     format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)))
                     format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor+1e-6)))
        return asr_result_list
    return _forward