From 8b1be8c3cba8987e9993619b46e59039ef3d6560 Mon Sep 17 00:00:00 2001
From: Tang Linjie <tanglinjie@cug.edu.cn>
Date: 星期六, 30 十一月 2024 13:05:39 +0800
Subject: [PATCH] feat: speed up fbank's lfr (#2246)

---
 funasr/frontends/wav_frontend.py |   55 +++++++++++++++++++++++++++----------------------------
 1 files changed, 27 insertions(+), 28 deletions(-)

diff --git a/funasr/frontends/wav_frontend.py b/funasr/frontends/wav_frontend.py
index 3324208..da23f9c 100644
--- a/funasr/frontends/wav_frontend.py
+++ b/funasr/frontends/wav_frontend.py
@@ -62,17 +62,16 @@
     left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
     inputs = torch.vstack((left_padding, inputs))
     T = T + (lfr_m - 1) // 2
-    for i in range(T_lfr):
-        if lfr_m <= T - i * lfr_n:
-            LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
-        else:  # process last LFR frame
-            num_padding = lfr_m - (T - i * lfr_n)
-            frame = (inputs[i * lfr_n :]).view(-1)
-            for _ in range(num_padding):
-                frame = torch.hstack((frame, inputs[-1]))
-            LFR_inputs.append(frame)
-    LFR_outputs = torch.vstack(LFR_inputs)
-    return LFR_outputs.type(torch.float32)
+    feat_dim = inputs.shape[-1]
+    strides = (lfr_n * feat_dim, 1)
+    sizes = (T_lfr, lfr_m * feat_dim)
+    last_idx = (T - lfr_m) // lfr_n + 1
+    num_padding = lfr_m - (T - last_idx * lfr_n)
+    if num_padding > 0:
+        num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
+        inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
+    LFR_outputs = inputs.as_strided(sizes, strides)
+    return LFR_outputs.clone().type(torch.float32)
 
 
 @tables.register("frontend_classes", "wav_frontend")
@@ -289,24 +288,24 @@
             np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
         )  # minus the right context: (lfr_m - 1) // 2
         splice_idx = T_lfr
-        for i in range(T_lfr):
-            if lfr_m <= T - i * lfr_n:
-                LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
-            else:  # process last LFR frame
-                if is_final:
-                    num_padding = lfr_m - (T - i * lfr_n)
-                    frame = (inputs[i * lfr_n :]).view(-1)
-                    for _ in range(num_padding):
-                        frame = torch.hstack((frame, inputs[-1]))
-                    LFR_inputs.append(frame)
-                else:
-                    # update splice_idx and break the circle
-                    splice_idx = i
-                    break
+        feat_dim = inputs.shape[-1]
+        ori_inputs = inputs
+        strides = (lfr_n * feat_dim, 1)
+        sizes = (T_lfr, lfr_m * feat_dim)
+        last_idx = (T - lfr_m) // lfr_n + 1
+        num_padding = lfr_m - (T - last_idx * lfr_n)
+        if is_final:
+            if num_padding > 0:
+                num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
+                inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
+        else:
+            if num_padding > 0:
+                sizes = (last_idx, lfr_m * feat_dim)
+                splice_idx = last_idx
         splice_idx = min(T - 1, splice_idx * lfr_n)
-        lfr_splice_cache = inputs[splice_idx:, :]
-        LFR_outputs = torch.vstack(LFR_inputs)
-        return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx
+        LFR_outputs = inputs[:splice_idx].as_strided(sizes, strides)
+        lfr_splice_cache = ori_inputs[splice_idx:, :]
+        return LFR_outputs.clone().type(torch.float32), lfr_splice_cache, splice_idx
 
     @staticmethod
     def compute_frame_num(

--
Gitblit v1.9.1