From 145022c438c8158e6bffdc5ff66fdd28468a0a18 Mon Sep 17 00:00:00 2001
From: zhong zhuang <zhuangz@lamda.nju.edu.cn>
Date: 星期五, 14 六月 2024 20:33:16 +0800
Subject: [PATCH] Update cif_predictor.py (#1811)

---
 funasr/models/paraformer/cif_predictor.py |    9 +++++----
 1 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 7490310..05e283a 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -504,7 +504,7 @@
     frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
     fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
 
-    prefix_sum = torch.cumsum(alphas, dim=1)
+    prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(torch.float32) # cumsum precision degradation causes wrong results in extreme cases
     prefix_sum_floor = torch.floor(prefix_sum)
     dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
     dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
@@ -539,7 +539,8 @@
 
     frames = frames - shift_frames + shift_remain_frames - remain_frames
 
-    max_label_len = batch_len.max()
+    max_label_len = alphas.sum(dim=-1)
+    max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64)
 
     frame_fires = torch.zeros(
         batch_size, max_label_len, hidden_size, dtype=dtype, device=device
@@ -670,7 +671,7 @@
 
     fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
 
-    prefix_sum = torch.cumsum(alphas, dim=1)
+    prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(torch.float32) # cumsum precision degradation causes wrong results in extreme cases
     prefix_sum_floor = torch.floor(prefix_sum)
     dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
     dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
@@ -718,7 +719,7 @@
 
     frames = frames - shift_frames + shift_remain_frames - remain_frames
 
-    max_label_len = batch_len.max()
+    max_label_len = torch.round(alphas.sum(-1)).int().max() # use torch.round to compute the max label length
 
     frame_fires = torch.zeros(
         batch_size, max_label_len, hidden_size, dtype=dtype, device=device

--
Gitblit v1.9.1