From ae49b2a8e1bc676e6014d8a12ebeec947b655e3e Mon Sep 17 00:00:00 2001
From: 莫拉古 <61447879+yechaoying@users.noreply.github.com>
Date: Fri, 29 Nov 2024 09:55:43 +0800
Subject: [PATCH] Fix wrong variable name (#2249)
---
funasr/models/paraformer/cif_predictor.py | 33 +++++++++++++++++++++++++--------
1 file changed, 25 insertions(+), 8 deletions(-)
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 0856eed..24145cd 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -80,7 +80,7 @@
hidden, alphas, token_num, mask=mask
)
- acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
@@ -245,7 +245,7 @@
hidden, alphas, token_num, mask=None
)
- acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
if target_length is None and self.tail_threshold > 0.0:
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
@@ -449,7 +449,7 @@
mask = mask.transpose(-1, -2).float()
mask = mask.squeeze(-1)
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
- acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold)
+ acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
return acoustic_embeds, token_num, alphas, cif_peak
@@ -506,7 +506,10 @@
frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
- prefix_sum = torch.cumsum(alphas, dim=1)
+ # prefix_sum = torch.cumsum(alphas, dim=1)
+ prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
+ torch.float32
+ ) # cumsum precision degradation causes wrong results in extreme cases
prefix_sum_floor = torch.floor(prefix_sum)
dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
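Why the float64 cumsum in the hunk above matters: prefix_sum feeds the floor-based fire detection below it, and a float32 running sum over a long utterance can drift far enough to cross an integer boundary and move a fire. A minimal sketch of the effect with synthetic alphas (shapes and values invented for illustration, not FunASR data):

    import torch

    # Synthetic weights standing in for the predictor's alphas.
    torch.manual_seed(0)
    alphas = torch.rand(1, 100_000) * 0.05  # float32, long sequence

    ps32 = torch.cumsum(alphas, dim=1)  # running sum kept in float32
    ps64 = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(torch.float32)  # patched path

    # Fires are detected where floor(prefix_sum) increases, so drift across
    # an integer boundary changes where a token is emitted.
    flips = (torch.floor(ps32) != torch.floor(ps64)).sum().item()
    print(f"positions where the fire decision differs: {flips}")  # typically nonzero on sequences this long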
@@ -518,8 +521,8 @@
fires[fire_idxs] = 1
fires = fires + prefix_sum - prefix_sum_floor
+ # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
-
frames = prefix_sum_hidden[fire_idxs]
shift_frames = torch.roll(frames, 1, dims=0)
@@ -530,6 +533,7 @@
shift_frames[shift_batch_idxs] = 0
remains = fires - torch.floor(fires)
+ # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
@@ -537,8 +541,11 @@
frames = frames - shift_frames + shift_remain_frames - remain_frames
- max_label_len = batch_len.max()
+ # max_label_len = batch_len.max()
+ max_label_len = alphas.sum(dim=-1)
+ max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64)
+ # frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
frame_fires_idxs = indices < batch_len.unsqueeze(1)
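The max_label_len change above derives the frame budget from alphas instead of batch_len: each sequence fires roughly alphas.sum(-1) times, so flooring the per-sequence totals and taking the batch max bounds frame_fires. A toy illustration (values invented for the sketch):

    import torch

    # Two sequences whose alpha totals are ~3.7 and ~2.2.
    alphas = torch.tensor([[0.9, 0.9, 0.9, 1.0],
                           [0.6, 0.6, 0.5, 0.5]])

    max_label_len = torch.floor(alphas.sum(dim=-1)).max().to(dtype=torch.int64)
    print(max_label_len.item())  # 3 -- floor(3.7)=3, floor(2.2)=2, max is 3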
@@ -667,7 +674,10 @@
fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
- prefix_sum = torch.cumsum(alphas, dim=1)
+ # prefix_sum = torch.cumsum(alphas, dim=1)
+ prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(
+ torch.float32
+ ) # cumsum precision degradation causes wrong results in extreme cases
prefix_sum_floor = torch.floor(prefix_sum)
dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
@@ -689,6 +699,8 @@
device = hidden.device
dtype = hidden.dtype
batch_size, len_time, hidden_size = hidden.size()
+ # frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
+ # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
@@ -702,6 +714,7 @@
shift_frames[shift_batch_idxs] = 0
remains = fires - torch.floor(fires)
+ # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
@@ -709,8 +722,12 @@
frames = frames - shift_frames + shift_remain_frames - remain_frames
- max_label_len = batch_len.max()
+ # max_label_len = batch_len.max()
+ max_label_len = (
+ torch.round(alphas.sum(-1)).int().max()
+ ) # torch.round to calculate the max length
+ # frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
frame_fires_idxs = indices < batch_len.unsqueeze(1)
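Unlike the earlier hunk, this export path rounds the alpha totals instead of flooring them, so a total that lands just under an integer (e.g. 2.999 after float error) still reserves its last slot. A quick comparison on a made-up total:

    import torch

    total = torch.tensor([2.999])           # per-sequence alpha sum just under an integer
    print(torch.floor(total).int().item())  # 2 -- floor drops the last slot
    print(torch.round(total).int().item())  # 3 -- round keeps it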
--
Gitblit v1.9.1