From 1a45b647a858b72c711c64bfe6b2e5333df6fd86 Mon Sep 17 00:00:00 2001
From: Truco <22969604+truc0@users.noreply.github.com>
Date: 星期一, 28 十月 2024 13:41:38 +0800
Subject: [PATCH] perf(models/FsmnVADStreaming): optimize GetFrameState and PopDataToOutputBuf (#2177)
---
funasr/models/fsmn_vad_streaming/model.py | 22 +++++++++++-----------
1 files changed, 11 insertions(+), 11 deletions(-)
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 03aaca7..bfffca8 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -406,7 +406,6 @@
cur_seg = cache["stats"].output_data_buf[-1]
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
print("warning\n")
- out_pos = len(cur_seg.buffer) # cur_seg.buff鐜板湪娌″仛浠讳綍鎿嶄綔
data_to_pop = 0
if end_point_is_sent_end:
data_to_pop = expected_sample_number
@@ -420,12 +419,6 @@
expected_sample_number = len(cache["stats"].data_buf)
cur_seg.doa = 0
- for sample_cpy_out in range(0, data_to_pop):
- # cur_seg.buffer[out_pos ++] = data_buf_.back();
- out_pos += 1
- for sample_cpy_out in range(data_to_pop, expected_sample_number):
- # cur_seg.buffer[out_pos++] = data_buf_.back()
- out_pos += 1
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
print("Something wrong with the VAD algorithm\n")
cache["stats"].data_buf_start_frame += frm_cnt
@@ -512,10 +505,17 @@
assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
if len(cache["stats"].sil_pdf_ids) > 0:
assert len(cache["stats"].scores) == 1 # 鍙敮鎸乥atch_size = 1鐨勬祴璇�
- sil_pdf_scores = [
- cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids
- ]
- sum_score = sum(sil_pdf_scores)
+ """
+ - Change type of `sum_score` to float. The reason is that `sum_score` is a tensor with single element.
+ and `torch.Tensor` is slower `float` when tensor has only one element.
+ - Put the iteration of `sil_pdf_ids` inside `sum()` to reduce the overhead of creating a new list.
+ - The default `sil_pdf_ids` is [0], the `if` statement is used to reduce the overhead of expression
+ generation, which result in a mere (~2%) performance gain.
+ """
+ if len(cache["stats"].sil_pdf_ids) > 1:
+ sum_score = sum(cache["stats"].scores[0][t][sil_pdf_id].item() for sil_pdf_id in cache["stats"].sil_pdf_ids)
+ else:
+ sum_score = cache["stats"].scores[0][t][cache["stats"].sil_pdf_ids[0]].item()
noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
total_score = 1.0
sum_score = total_score - sum_score
--
Gitblit v1.9.1