From aa3fe1a353bde71d106755d030d9e5300fbde328 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 22 七月 2024 19:02:15 +0800
Subject: [PATCH] python runtime

---
 funasr/models/paraformer/cif_predictor.py |  127 ++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 127 insertions(+), 0 deletions(-)

diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 8b1a9bb..24145cd 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -497,6 +497,63 @@
 
 
 @torch.jit.script
+def cif_v1_export(hidden, alphas, threshold: float):
+    device = hidden.device
+    dtype = hidden.dtype
+    batch_size, len_time, hidden_size = hidden.size()
+    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
+
+    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
+    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
+
+    # prefix_sum = torch.cumsum(alphas, dim=1)
+    prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
+        torch.float32
+    )  # cumsum precision degradation cause wrong result in extreme
+    prefix_sum_floor = torch.floor(prefix_sum)
+    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
+    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
+
+    dislocation_prefix_sum_floor[:, 0] = 0
+    dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
+
+    fire_idxs = dislocation_diff > 0
+    fires[fire_idxs] = 1
+    fires = fires + prefix_sum - prefix_sum_floor
+
+    # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
+    prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
+    frames = prefix_sum_hidden[fire_idxs]
+    shift_frames = torch.roll(frames, 1, dims=0)
+
+    batch_len = fire_idxs.sum(1)
+    batch_idxs = torch.cumsum(batch_len, dim=0)
+    shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
+    shift_batch_idxs[0] = 0
+    shift_frames[shift_batch_idxs] = 0
+
+    remains = fires - torch.floor(fires)
+    # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
+    remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
+
+    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
+    shift_remain_frames[shift_batch_idxs] = 0
+
+    frames = frames - shift_frames + shift_remain_frames - remain_frames
+
+    # max_label_len = batch_len.max()
+    max_label_len = alphas.sum(dim=-1)
+    max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64)
+
+    # frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
+    frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
+    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
+    frame_fires_idxs = indices < batch_len.unsqueeze(1)
+    frame_fires[frame_fires_idxs] = frames
+    return frame_fires, fires
+
+
+@torch.jit.script
 def cif_export(hidden, alphas, threshold: float):
     batch_size, len_time, hidden_size = hidden.size()
     threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
@@ -608,6 +665,76 @@
     return torch.stack(list_ls, 0), fires
 
 
+def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
+    batch_size, len_time = alphas.size()
+    device = alphas.device
+    dtype = alphas.dtype
+
+    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
+
+    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
+
+    # prefix_sum = torch.cumsum(alphas, dim=1)
+    prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
+        torch.float32
+    )  # cumsum precision degradation cause wrong result in extreme
+    prefix_sum_floor = torch.floor(prefix_sum)
+    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
+    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
+
+    dislocation_prefix_sum_floor[:, 0] = 0
+    dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
+
+    fire_idxs = dislocation_diff > 0
+    fires[fire_idxs] = 1
+    fires = fires + prefix_sum - prefix_sum_floor
+    if return_fire_idxs:
+        return fires, fire_idxs
+    return fires
+
+
+def cif_v1(hidden, alphas, threshold):
+    fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)
+
+    device = hidden.device
+    dtype = hidden.dtype
+    batch_size, len_time, hidden_size = hidden.size()
+    # frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
+    # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
+    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
+    prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
+
+    frames = prefix_sum_hidden[fire_idxs]
+    shift_frames = torch.roll(frames, 1, dims=0)
+
+    batch_len = fire_idxs.sum(1)
+    batch_idxs = torch.cumsum(batch_len, dim=0)
+    shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
+    shift_batch_idxs[0] = 0
+    shift_frames[shift_batch_idxs] = 0
+
+    remains = fires - torch.floor(fires)
+    # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
+    remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
+
+    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
+    shift_remain_frames[shift_batch_idxs] = 0
+
+    frames = frames - shift_frames + shift_remain_frames - remain_frames
+
+    # max_label_len = batch_len.max()
+    max_label_len = (
+        torch.round(alphas.sum(-1)).int().max()
+    )  # torch.round to calculate the max length
+
+    # frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
+    frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
+    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
+    frame_fires_idxs = indices < batch_len.unsqueeze(1)
+    frame_fires[frame_fires_idxs] = frames
+    return frame_fires, fires
+
+
 def cif_wo_hidden(alphas, threshold):
     batch_size, len_time = alphas.size()
 

--
Gitblit v1.9.1