From 93f9a424f2bc0607d31ef66b0c7c58dfac15ce25 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 10:07:31 +0800
Subject: [PATCH] fixbug for cif

---
 funasr/models/paraformer/cif_predictor.py |   27 ++++++++++++++++++++++-----
 1 file changed, 22 insertions(+), 5 deletions(-)

diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 83ca464..535131f 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -506,7 +506,10 @@
     frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
     fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
 
-    prefix_sum = torch.cumsum(alphas, dim=1)
+    # prefix_sum = torch.cumsum(alphas, dim=1)
+    prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
+        torch.float32
+    )  # cumsum precision degradation cause wrong result in extreme
     prefix_sum_floor = torch.floor(prefix_sum)
     dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
     dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
@@ -518,8 +521,8 @@
     fires[fire_idxs] = 1
     fires = fires + prefix_sum - prefix_sum_floor
 
+    # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
     prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
-
     frames = prefix_sum_hidden[fire_idxs]
     shift_frames = torch.roll(frames, 1, dims=0)
 
@@ -530,6 +533,7 @@
     shift_frames[shift_batch_idxs] = 0
 
     remains = fires - torch.floor(fires)
+    # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
     remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
 
     shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
@@ -537,8 +541,11 @@
 
     frames = frames - shift_frames + shift_remain_frames - remain_frames
 
-    max_label_len = batch_len.max()
+    # max_label_len = batch_len.max()
+    max_label_len = alphas.sum(dim=-1)
+    max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64)
 
+    # frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
     frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
     indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
     frame_fires_idxs = indices < batch_len.unsqueeze(1)
@@ -667,7 +674,10 @@
 
     fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
 
-    prefix_sum = torch.cumsum(alphas, dim=1)
+    # prefix_sum = torch.cumsum(alphas, dim=1)
+    prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
+        torch.float32
+    )  # cumsum precision degradation cause wrong result in extreme
     prefix_sum_floor = torch.floor(prefix_sum)
     dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
     dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
@@ -689,6 +699,8 @@
     device = hidden.device
     dtype = hidden.dtype
     batch_size, len_time, hidden_size = hidden.size()
+    # frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
+    # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
     frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
     prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
 
@@ -702,6 +714,7 @@
     shift_frames[shift_batch_idxs] = 0
 
     remains = fires - torch.floor(fires)
+    # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
     remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
 
     shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
@@ -709,8 +722,12 @@
 
     frames = frames - shift_frames + shift_remain_frames - remain_frames
 
-    max_label_len = batch_len.max()
+    # max_label_len = batch_len.max()
+    max_label_len = (
+        torch.round(alphas.sum(-1)).int().max()
+    )  # torch.round to calculate the max length
 
+    # frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
     frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
     indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
     frame_fires_idxs = indices < batch_len.unsqueeze(1)

--
Gitblit v1.9.1