From c367f1b819bb40e70e379f8f80b1d9fd67e47c79 Mon Sep 17 00:00:00 2001
From: zhuangzhong <zhuangzhong@corp.netease.com>
Date: 星期五, 07 六月 2024 16:15:16 +0800
Subject: [PATCH] add cif_wo_hidden_v1

---
 funasr/models/paraformer/cif_predictor.py |   24 ++++++++++++++++++------
 1 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index a6bfe65..7490310 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -661,14 +661,13 @@
     return torch.stack(list_ls, 0), fires
 
 
-def cif_v1(hidden, alphas, threshold):
+def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
+    batch_size, len_time = alphas.size()
+    device = alphas.device
+    dtype = alphas.dtype
 
-    device = hidden.device
-    dtype = hidden.dtype
-    batch_size, len_time, hidden_size = hidden.size()
     threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
 
-    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
     fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
 
     prefix_sum = torch.cumsum(alphas, dim=1)
@@ -682,7 +681,19 @@
     fire_idxs = dislocation_diff > 0
     fires[fire_idxs] = 1
     fires = fires + prefix_sum - prefix_sum_floor
+    if return_fire_idxs:
+        return fires, fire_idxs
+    return fires
 
+
+def cif_v1(hidden, alphas, threshold):
+    fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)
+
+    device = hidden.device
+    dtype = hidden.dtype
+    batch_size, len_time, hidden_size = hidden.size()
+    frames = torch.zeros(batch_size, len_time, hidden_size,
+                         dtype=dtype, device=device)
     prefix_sum_hidden = torch.cumsum(
         alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
     )
@@ -698,7 +709,8 @@
 
     remains = fires - torch.floor(fires)
     remain_frames = (
-        remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
+        remains[fire_idxs].unsqueeze(-1).tile((1,
+                                               hidden_size)) * hidden[fire_idxs]
     )
 
     shift_remain_frames = torch.roll(remain_frames, 1, dims=0)

--
Gitblit v1.9.1