From 1c52e364aa987dd03b4e9f52e0b725b6f335863b Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 21 六月 2024 11:19:19 +0800
Subject: [PATCH] version checker

---
 funasr/models/paraformer/cif_predictor.py |   31 ++++++++++---------------------
 1 files changed, 10 insertions(+), 21 deletions(-)

diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 7490310..0856eed 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -494,6 +494,8 @@
         token_num_floor = torch.floor(token_num)
 
         return hidden, alphas, token_num_floor
+
+
 @torch.jit.script
 def cif_v1_export(hidden, alphas, threshold: float):
     device = hidden.device
@@ -516,9 +518,7 @@
     fires[fire_idxs] = 1
     fires = fires + prefix_sum - prefix_sum_floor
 
-    prefix_sum_hidden = torch.cumsum(
-        alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
-    )
+    prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
 
     frames = prefix_sum_hidden[fire_idxs]
     shift_frames = torch.roll(frames, 1, dims=0)
@@ -530,9 +530,7 @@
     shift_frames[shift_batch_idxs] = 0
 
     remains = fires - torch.floor(fires)
-    remain_frames = (
-        remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
-    )
+    remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
 
     shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
     shift_remain_frames[shift_batch_idxs] = 0
@@ -541,13 +539,12 @@
 
     max_label_len = batch_len.max()
 
-    frame_fires = torch.zeros(
-        batch_size, max_label_len, hidden_size, dtype=dtype, device=device
-    )
+    frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
     indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
     frame_fires_idxs = indices < batch_len.unsqueeze(1)
     frame_fires[frame_fires_idxs] = frames
     return frame_fires, fires
+
 
 @torch.jit.script
 def cif_export(hidden, alphas, threshold: float):
@@ -692,11 +689,8 @@
     device = hidden.device
     dtype = hidden.dtype
     batch_size, len_time, hidden_size = hidden.size()
-    frames = torch.zeros(batch_size, len_time, hidden_size,
-                         dtype=dtype, device=device)
-    prefix_sum_hidden = torch.cumsum(
-        alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
-    )
+    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
+    prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
 
     frames = prefix_sum_hidden[fire_idxs]
     shift_frames = torch.roll(frames, 1, dims=0)
@@ -708,10 +702,7 @@
     shift_frames[shift_batch_idxs] = 0
 
     remains = fires - torch.floor(fires)
-    remain_frames = (
-        remains[fire_idxs].unsqueeze(-1).tile((1,
-                                               hidden_size)) * hidden[fire_idxs]
-    )
+    remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
 
     shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
     shift_remain_frames[shift_batch_idxs] = 0
@@ -720,9 +711,7 @@
 
     max_label_len = batch_len.max()
 
-    frame_fires = torch.zeros(
-        batch_size, max_label_len, hidden_size, dtype=dtype, device=device
-    )
+    frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
     indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
     frame_fires_idxs = indices < batch_len.unsqueeze(1)
     frame_fires[frame_fires_idxs] = frames

--
Gitblit v1.9.1