From 20aa07268a7fafaaab7762b488615af32a0e82b4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 11 六月 2024 14:02:27 +0800
Subject: [PATCH] update with main (#1800)
---
funasr/models/paraformer/cif_predictor.py | 127 +++++++++++++++++++++++++++++++++++++++++-
1 files changed, 124 insertions(+), 3 deletions(-)
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 8b1a9bb..7490310 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -80,7 +80,7 @@
hidden, alphas, token_num, mask=mask
)
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+ acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
@@ -245,7 +245,7 @@
hidden, alphas, token_num, mask=None
)
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+ acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
@@ -449,7 +449,7 @@
mask = mask.transpose(-1, -2).float()
mask = mask.squeeze(-1)
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
- acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
+ acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold)
return acoustic_embeds, token_num, alphas, cif_peak
@@ -494,7 +494,60 @@
token_num_floor = torch.floor(token_num)
return hidden, alphas, token_num_floor
+@torch.jit.script
+def cif_v1_export(hidden, alphas, threshold: float):
+ device = hidden.device
+ dtype = hidden.dtype
+ batch_size, len_time, hidden_size = hidden.size()
+ threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
+ frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
+ fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
+
+ prefix_sum = torch.cumsum(alphas, dim=1)
+ prefix_sum_floor = torch.floor(prefix_sum)
+ dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
+ dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
+
+ dislocation_prefix_sum_floor[:, 0] = 0
+ dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
+
+ fire_idxs = dislocation_diff > 0
+ fires[fire_idxs] = 1
+ fires = fires + prefix_sum - prefix_sum_floor
+
+ prefix_sum_hidden = torch.cumsum(
+ alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
+ )
+
+ frames = prefix_sum_hidden[fire_idxs]
+ shift_frames = torch.roll(frames, 1, dims=0)
+
+ batch_len = fire_idxs.sum(1)
+ batch_idxs = torch.cumsum(batch_len, dim=0)
+ shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
+ shift_batch_idxs[0] = 0
+ shift_frames[shift_batch_idxs] = 0
+
+ remains = fires - torch.floor(fires)
+ remain_frames = (
+ remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
+ )
+
+ shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
+ shift_remain_frames[shift_batch_idxs] = 0
+
+ frames = frames - shift_frames + shift_remain_frames - remain_frames
+
+ max_label_len = batch_len.max()
+
+ frame_fires = torch.zeros(
+ batch_size, max_label_len, hidden_size, dtype=dtype, device=device
+ )
+ indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
+ frame_fires_idxs = indices < batch_len.unsqueeze(1)
+ frame_fires[frame_fires_idxs] = frames
+ return frame_fires, fires
@torch.jit.script
def cif_export(hidden, alphas, threshold: float):
@@ -608,6 +661,74 @@
return torch.stack(list_ls, 0), fires
+def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
+ batch_size, len_time = alphas.size()
+ device = alphas.device
+ dtype = alphas.dtype
+
+ threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
+
+ fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
+
+ prefix_sum = torch.cumsum(alphas, dim=1)
+ prefix_sum_floor = torch.floor(prefix_sum)
+ dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
+ dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
+
+ dislocation_prefix_sum_floor[:, 0] = 0
+ dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
+
+ fire_idxs = dislocation_diff > 0
+ fires[fire_idxs] = 1
+ fires = fires + prefix_sum - prefix_sum_floor
+ if return_fire_idxs:
+ return fires, fire_idxs
+ return fires
+
+
+def cif_v1(hidden, alphas, threshold):
+ fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)
+
+ device = hidden.device
+ dtype = hidden.dtype
+ batch_size, len_time, hidden_size = hidden.size()
+ frames = torch.zeros(batch_size, len_time, hidden_size,
+ dtype=dtype, device=device)
+ prefix_sum_hidden = torch.cumsum(
+ alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
+ )
+
+ frames = prefix_sum_hidden[fire_idxs]
+ shift_frames = torch.roll(frames, 1, dims=0)
+
+ batch_len = fire_idxs.sum(1)
+ batch_idxs = torch.cumsum(batch_len, dim=0)
+ shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
+ shift_batch_idxs[0] = 0
+ shift_frames[shift_batch_idxs] = 0
+
+ remains = fires - torch.floor(fires)
+ remain_frames = (
+ remains[fire_idxs].unsqueeze(-1).tile((1,
+ hidden_size)) * hidden[fire_idxs]
+ )
+
+ shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
+ shift_remain_frames[shift_batch_idxs] = 0
+
+ frames = frames - shift_frames + shift_remain_frames - remain_frames
+
+ max_label_len = batch_len.max()
+
+ frame_fires = torch.zeros(
+ batch_size, max_label_len, hidden_size, dtype=dtype, device=device
+ )
+ indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
+ frame_fires_idxs = indices < batch_len.unsqueeze(1)
+ frame_fires[frame_fires_idxs] = frames
+ return frame_fires, fires
+
+
def cif_wo_hidden(alphas, threshold):
batch_size, len_time = alphas.size()
--
Gitblit v1.9.1