Tang Linjie
2024-11-30 8b1be8c3cba8987e9993619b46e59039ef3d6560
feat: speed up fbank's lfr (#2246)

Co-authored-by: linjie.tang <linjie.tang@sophgo.com>
1个文件已修改
51 ■■■■ 已修改文件
funasr/frontends/wav_frontend.py 51 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/frontends/wav_frontend.py
@@ -62,17 +62,16 @@
    left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
    inputs = torch.vstack((left_padding, inputs))
    T = T + (lfr_m - 1) // 2
    for i in range(T_lfr):
        if lfr_m <= T - i * lfr_n:
            LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
        else:  # process last LFR frame
            num_padding = lfr_m - (T - i * lfr_n)
            frame = (inputs[i * lfr_n :]).view(-1)
            for _ in range(num_padding):
                frame = torch.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    LFR_outputs = torch.vstack(LFR_inputs)
    return LFR_outputs.type(torch.float32)
    feat_dim = inputs.shape[-1]
    strides = (lfr_n * feat_dim, 1)
    sizes = (T_lfr, lfr_m * feat_dim)
    last_idx = (T - lfr_m) // lfr_n + 1
    num_padding = lfr_m - (T - last_idx * lfr_n)
    if num_padding > 0:
        num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
        inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
    LFR_outputs = inputs.as_strided(sizes, strides)
    return LFR_outputs.clone().type(torch.float32)
@tables.register("frontend_classes", "wav_frontend")
@@ -289,24 +288,24 @@
            np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
        )  # minus the right context: (lfr_m - 1) // 2
        splice_idx = T_lfr
        for i in range(T_lfr):
            if lfr_m <= T - i * lfr_n:
                LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
            else:  # process last LFR frame
        feat_dim = inputs.shape[-1]
        ori_inputs = inputs
        strides = (lfr_n * feat_dim, 1)
        sizes = (T_lfr, lfr_m * feat_dim)
        last_idx = (T - lfr_m) // lfr_n + 1
        num_padding = lfr_m - (T - last_idx * lfr_n)
                if is_final:
                    num_padding = lfr_m - (T - i * lfr_n)
                    frame = (inputs[i * lfr_n :]).view(-1)
                    for _ in range(num_padding):
                        frame = torch.hstack((frame, inputs[-1]))
                    LFR_inputs.append(frame)
            if num_padding > 0:
                num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx)
                inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
                else:
                    # update splice_idx and break the circle
                    splice_idx = i
                    break
            if num_padding > 0:
                sizes = (last_idx, lfr_m * feat_dim)
                splice_idx = last_idx
        splice_idx = min(T - 1, splice_idx * lfr_n)
        lfr_splice_cache = inputs[splice_idx:, :]
        LFR_outputs = torch.vstack(LFR_inputs)
        return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx
        LFR_outputs = inputs[:splice_idx].as_strided(sizes, strides)
        lfr_splice_cache = ori_inputs[splice_idx:, :]
        return LFR_outputs.clone().type(torch.float32), lfr_splice_cache, splice_idx
    @staticmethod
    def compute_frame_num(