lyblsgo
2023-01-18 49df645df7508eb8085af6f2762a474c2ebf4941
fix ctc module
4个文件已修改
30 ■■■■ 已修改文件
funasr/bin/asr_inference_paraformer.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_timestamp.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_vad_punc.py 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_uniasr.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer.py
@@ -95,10 +95,13 @@
         logging.info("asr_train_args: {}".format(asr_train_args))
         asr_model.to(dtype=getattr(torch, dtype)).eval()
-        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        if asr_model.ctc != None:
+            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+            scorers.update(
+                ctc=ctc
+            )
         token_list = asr_model.token_list
         scorers.update(
-            ctc=ctc,
             length_bonus=LengthBonus(len(token_list)),
         )
funasr/bin/asr_inference_paraformer_timestamp.py
@@ -98,10 +98,13 @@
         logging.info("asr_train_args: {}".format(asr_train_args))
         asr_model.to(dtype=getattr(torch, dtype)).eval()
-        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        if asr_model.ctc != None:
+            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+            scorers.update(
+                ctc=ctc
+            )
         token_list = asr_model.token_list
         scorers.update(
-            ctc=ctc,
             length_bonus=LengthBonus(len(token_list)),
         )
funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -100,10 +100,13 @@
         # logging.info("asr_train_args: {}".format(asr_train_args))
         asr_model.to(dtype=getattr(torch, dtype)).eval()
-        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        if asr_model.ctc != None:
+            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+            scorers.update(
+                ctc=ctc
+            )
         token_list = asr_model.token_list
         scorers.update(
-            ctc=ctc,
             length_bonus=LengthBonus(len(token_list)),
         )
@@ -663,7 +666,7 @@
                                                                                          time_stamp_postprocessed))
         
         logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
-                     format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor+1e-6)))
+                     format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)))
         return asr_result_list
     return _forward
funasr/bin/asr_inference_uniasr.py
@@ -96,11 +96,14 @@
         else:
             decoder = asr_model.decoder2
-        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        if asr_model.ctc != None:
+            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+            scorers.update(
+                ctc=ctc
+            )
         token_list = asr_model.token_list
         scorers.update(
             decoder=decoder,
-            ctc=ctc,
             length_bonus=LengthBonus(len(token_list)),
         )