游雁
2023-05-22 3cde6fe8e0457a0cd3d65797d6d53bd947047a44
bugfix paraformer large long
3个文件已修改
35 ■■■■ 已修改文件
egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_launch.py 28 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/version.txt 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr_vad_punc/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/demo.py
@@ -3,15 +3,14 @@
if __name__ == '__main__':
    audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav'
    output_dir = None
    output_dir = "./results"
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
        vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
        punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
        output_dir=output_dir,
        batch_size=64,
    )
    rec_result = inference_pipeline(audio_in=audio_in)
    rec_result = inference_pipeline(audio_in=audio_in, batch_size_token=5000)
    print(rec_result)
funasr/bin/asr_inference_launch.py
@@ -600,6 +600,9 @@
        if 'hotword' in kwargs:
            hotword_list_or_file = kwargs['hotword']
        
        batch_size_token = kwargs.get("batch_size_token", 6000)
        print("batch_size_token: ", batch_size_token)
        if speech2text.hotword_list is None:
            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
        
@@ -642,8 +645,10 @@
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            beg_vad = time.time()
            vad_results = speech2vadsegment(**batch)
            end_vad = time.time()
            print("time cost vad: ", end_vad-beg_vad)
            _, vadsegments = vad_results[0], vad_results[1][0]
            
            speech, speech_lengths = batch["speech"], batch["speech_lengths"]
@@ -652,17 +657,29 @@
            data_with_index = [(vadsegments[i], i) for i in range(n)]
            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
            results_sorted = []
            for j, beg_idx in enumerate(range(0, n, batch_size)):
                end_idx = min(n, beg_idx + batch_size)
            batch_size_token_ms = batch_size_token*60
            batch_size_token_ms_cum = 0
            beg_idx = 0
            for j, _ in enumerate(range(0, n)):
                batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                if j < n-1 and (batch_size_token_ms_cum + sorted_data[j+1][0][1] - sorted_data[j+1][0][0])<batch_size_token_ms:
                    continue
                batch_size_token_ms_cum = 0
                end_idx = j + 1
                speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
                beg_idx = end_idx
                batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
                batch = to_device(batch, device=device)
                print("batch: ", speech_j.shape[0])
                beg_asr = time.time()
                results = speech2text(**batch)
                end_asr = time.time()
                print("time cost asr: ", end_asr - beg_asr)
                
                if len(results) < 1:
                    results = [["", [], [], [], [], [], []]]
                results_sorted.extend(results)
            restored_data = [0] * n
            for j in range(n):
                index = sorted_data[j][1]
@@ -701,7 +718,10 @@
            text_postprocessed_punc = text_postprocessed
            punc_id_list = []
            if len(word_lists) > 0 and text2punc is not None:
                beg_punc = time.time()
                text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
                end_punc = time.time()
                print("time cost punc: ", end_punc-beg_punc)
            
            item = {'key': key, 'value': text_postprocessed_punc}
            if text_postprocessed != "":
funasr/version.txt
@@ -1 +1 @@
0.5.4
0.5.5