From 8b1be8c3cba8987e9993619b46e59039ef3d6560 Mon Sep 17 00:00:00 2001
From: Tang Linjie <tanglinjie@cug.edu.cn>
Date: 星期六, 30 十一月 2024 13:05:39 +0800
Subject: [PATCH] feat: speed up fbank's lfr (#2246)
---
funasr/frontends/wav_frontend.py | 55 +++++++++++++++++++++++++++----------------------------
1 files changed, 27 insertions(+), 28 deletions(-)
diff --git a/funasr/frontends/wav_frontend.py b/funasr/frontends/wav_frontend.py
index 3324208..da23f9c 100644
--- a/funasr/frontends/wav_frontend.py
+++ b/funasr/frontends/wav_frontend.py
@@ -62,17 +62,16 @@
left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
inputs = torch.vstack((left_padding, inputs))
T = T + (lfr_m - 1) // 2
- for i in range(T_lfr):
- if lfr_m <= T - i * lfr_n:
- LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
- else: # process last LFR frame
- num_padding = lfr_m - (T - i * lfr_n)
- frame = (inputs[i * lfr_n :]).view(-1)
- for _ in range(num_padding):
- frame = torch.hstack((frame, inputs[-1]))
- LFR_inputs.append(frame)
- LFR_outputs = torch.vstack(LFR_inputs)
- return LFR_outputs.type(torch.float32)
+ feat_dim = inputs.shape[-1]
+ strides = (lfr_n * feat_dim, 1)
+ sizes = (T_lfr, lfr_m * feat_dim)
+ last_idx = (T - lfr_m) // lfr_n + 1
+ num_padding = lfr_m - (T - last_idx * lfr_n)
+ if num_padding > 0:
+ num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
+ inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
+ LFR_outputs = inputs.as_strided(sizes, strides)
+ return LFR_outputs.clone().type(torch.float32)
@tables.register("frontend_classes", "wav_frontend")
@@ -289,24 +288,24 @@
np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
) # minus the right context: (lfr_m - 1) // 2
splice_idx = T_lfr
- for i in range(T_lfr):
- if lfr_m <= T - i * lfr_n:
- LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
- else: # process last LFR frame
- if is_final:
- num_padding = lfr_m - (T - i * lfr_n)
- frame = (inputs[i * lfr_n :]).view(-1)
- for _ in range(num_padding):
- frame = torch.hstack((frame, inputs[-1]))
- LFR_inputs.append(frame)
- else:
- # update splice_idx and break the circle
- splice_idx = i
- break
+ feat_dim = inputs.shape[-1]
+ ori_inputs = inputs
+ strides = (lfr_n * feat_dim, 1)
+ sizes = (T_lfr, lfr_m * feat_dim)
+ last_idx = (T - lfr_m) // lfr_n + 1
+ num_padding = lfr_m - (T - last_idx * lfr_n)
+ if is_final:
+ if num_padding > 0:
+ num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
+ inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
+ else:
+ if num_padding > 0:
+ sizes = (last_idx, lfr_m * feat_dim)
+ splice_idx = last_idx
splice_idx = min(T - 1, splice_idx * lfr_n)
- lfr_splice_cache = inputs[splice_idx:, :]
- LFR_outputs = torch.vstack(LFR_inputs)
- return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx
+ LFR_outputs = inputs[:splice_idx].as_strided(sizes, strides)
+ lfr_splice_cache = ori_inputs[splice_idx:, :]
+ return LFR_outputs.clone().type(torch.float32), lfr_splice_cache, splice_idx
@staticmethod
def compute_frame_num(
--
Gitblit v1.9.1