Truco
2024-10-28 1a45b647a858b72c711c64bfe6b2e5333df6fd86
perf(models/FsmnVADStreaming): optimize GetFrameState and PopDataToOutputBuf (#2177)

- In GetFrameState(), pass a generator to sum() instead of building a list; ~10% gain on a 21s sample
- In GetFrameState(), convert `sum_score` (a tensor) to float to reduce calls into the tensor library;
~13% gain on a 23s sample
- In PopDataToOutputBuf(), remove the unused `out_pos` and its related calculation; ~10% gain on a 27s sample
1个文件已修改
22 ■■■■ 已修改文件
funasr/models/fsmn_vad_streaming/model.py 22 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/fsmn_vad_streaming/model.py
@@ -406,7 +406,6 @@
        cur_seg = cache["stats"].output_data_buf[-1]
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print("warning\n")
        out_pos = len(cur_seg.buffer)  # cur_seg.buff现在没做任何操作
        data_to_pop = 0
        if end_point_is_sent_end:
            data_to_pop = expected_sample_number
@@ -420,12 +419,6 @@
            expected_sample_number = len(cache["stats"].data_buf)
        cur_seg.doa = 0
        for sample_cpy_out in range(0, data_to_pop):
            # cur_seg.buffer[out_pos ++] = data_buf_.back();
            out_pos += 1
        for sample_cpy_out in range(data_to_pop, expected_sample_number):
            # cur_seg.buffer[out_pos++] = data_buf_.back()
            out_pos += 1
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print("Something wrong with the VAD algorithm\n")
        cache["stats"].data_buf_start_frame += frm_cnt
@@ -512,10 +505,17 @@
        assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
        if len(cache["stats"].sil_pdf_ids) > 0:
            assert len(cache["stats"].scores) == 1  # 只支持batch_size = 1的测试
            sil_pdf_scores = [
                cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids
            ]
            sum_score = sum(sil_pdf_scores)
            """
            - Convert `sum_score` to float. `sum_score` is a single-element tensor, and operating on a
              `torch.Tensor` is slower than operating on a `float` when the tensor has only one element.
            - Iterate over `sil_pdf_ids` inside `sum()` to avoid the overhead of creating a new list.
            - The default `sil_pdf_ids` is [0]; the `if` statement avoids the overhead of creating a
              generator expression in that case, which results in a minor (~2%) performance gain.
            """
            if len(cache["stats"].sil_pdf_ids) > 1:
                sum_score = sum(cache["stats"].scores[0][t][sil_pdf_id].item() for sil_pdf_id in cache["stats"].sil_pdf_ids)
            else:
                sum_score = cache["stats"].scores[0][t][cache["stats"].sil_pdf_ids[0]].item()
            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
            total_score = 1.0
            sum_score = total_score - sum_score