| | |
| | | left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1) |
| | | inputs = torch.vstack((left_padding, inputs)) |
| | | T = T + (lfr_m - 1) // 2 |
| | | for i in range(T_lfr): |
| | | if lfr_m <= T - i * lfr_n: |
| | | LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1)) |
| | | else: # process last LFR frame |
| | | num_padding = lfr_m - (T - i * lfr_n) |
| | | frame = (inputs[i * lfr_n :]).view(-1) |
| | | for _ in range(num_padding): |
| | | frame = torch.hstack((frame, inputs[-1])) |
| | | LFR_inputs.append(frame) |
| | | LFR_outputs = torch.vstack(LFR_inputs) |
| | | return LFR_outputs.type(torch.float32) |
| | | feat_dim = inputs.shape[-1] |
| | | strides = (lfr_n * feat_dim, 1) |
| | | sizes = (T_lfr, lfr_m * feat_dim) |
| | | last_idx = (T - lfr_m) // lfr_n + 1 |
| | | num_padding = lfr_m - (T - last_idx * lfr_n) |
| | | if num_padding > 0: |
| | | num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx) |
| | | inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding)) |
| | | LFR_outputs = inputs.as_strided(sizes, strides) |
| | | return LFR_outputs.clone().type(torch.float32) |
| | | |
| | | |
| | | @tables.register("frontend_classes", "wav_frontend") |
| | |
| | | np.ceil((T - (lfr_m - 1) // 2) / lfr_n) |
| | | ) # minus the right context: (lfr_m - 1) // 2 |
| | | splice_idx = T_lfr |
| | | for i in range(T_lfr): |
| | | if lfr_m <= T - i * lfr_n: |
| | | LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1)) |
| | | else: # process last LFR frame |
| | | feat_dim = inputs.shape[-1] |
| | | ori_inputs = inputs |
| | | strides = (lfr_n * feat_dim, 1) |
| | | sizes = (T_lfr, lfr_m * feat_dim) |
| | | last_idx = (T - lfr_m) // lfr_n + 1 |
| | | num_padding = lfr_m - (T - last_idx * lfr_n) |
| | | if is_final: |
| | | num_padding = lfr_m - (T - i * lfr_n) |
| | | frame = (inputs[i * lfr_n :]).view(-1) |
| | | for _ in range(num_padding): |
| | | frame = torch.hstack((frame, inputs[-1])) |
| | | LFR_inputs.append(frame) |
| | | if num_padding > 0: |
| | | num_padding = (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n) / 2 * (T_lfr - last_idx) |
| | | inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding)) |
| | | else: |
| | | # update splice_idx and break out of the loop |
| | | splice_idx = i |
| | | break |
| | | if num_padding > 0: |
| | | sizes = (last_idx, lfr_m * feat_dim) |
| | | splice_idx = last_idx |
| | | splice_idx = min(T - 1, splice_idx * lfr_n) |
| | | lfr_splice_cache = inputs[splice_idx:, :] |
| | | LFR_outputs = torch.vstack(LFR_inputs) |
| | | return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx |
| | | LFR_outputs = inputs[:splice_idx].as_strided(sizes, strides) |
| | | lfr_splice_cache = ori_inputs[splice_idx:, :] |
| | | return LFR_outputs.clone().type(torch.float32), lfr_splice_cache, splice_idx |
| | | |
| | | @staticmethod |
| | | def compute_frame_num( |