From 45d7aa9004763684fb748ee17942ecba81042201 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 19 Jun 2024 10:26:40 +0800
Subject: [PATCH] decoding
---
funasr/models/paraformer/cif_predictor.py | 40 ++++++++++++++--------------------------
 1 file changed, 14 insertions(+), 26 deletions(-)
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 05e283a..0856eed 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -494,6 +494,8 @@
token_num_floor = torch.floor(token_num)
return hidden, alphas, token_num_floor
+
+
@torch.jit.script
def cif_v1_export(hidden, alphas, threshold: float):
device = hidden.device
@@ -504,7 +506,7 @@
frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
- prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(torch.float32) # cumsum precision degradation cause wrong result in extreme
+ prefix_sum = torch.cumsum(alphas, dim=1)
prefix_sum_floor = torch.floor(prefix_sum)
dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
@@ -516,9 +518,7 @@
fires[fire_idxs] = 1
fires = fires + prefix_sum - prefix_sum_floor
- prefix_sum_hidden = torch.cumsum(
- alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
- )
+ prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
frames = prefix_sum_hidden[fire_idxs]
shift_frames = torch.roll(frames, 1, dims=0)
@@ -530,25 +530,21 @@
shift_frames[shift_batch_idxs] = 0
remains = fires - torch.floor(fires)
- remain_frames = (
- remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
- )
+ remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
shift_remain_frames[shift_batch_idxs] = 0
frames = frames - shift_frames + shift_remain_frames - remain_frames
- max_label_len = alphas.sum(dim=-1)
- max_label_len = torch.floor(max_label_len).max().to(dtype=torch.int64)
+ max_label_len = batch_len.max()
- frame_fires = torch.zeros(
- batch_size, max_label_len, hidden_size, dtype=dtype, device=device
- )
+ frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
frame_fires_idxs = indices < batch_len.unsqueeze(1)
frame_fires[frame_fires_idxs] = frames
return frame_fires, fires
+
@torch.jit.script
def cif_export(hidden, alphas, threshold: float):
@@ -671,7 +667,7 @@
fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
- prefix_sum = torch.cumsum(alphas, dim=1, dtype=torch.float64).to(torch.float32) # cumsum precision degradation cause wrong result in extreme
+ prefix_sum = torch.cumsum(alphas, dim=1)
prefix_sum_floor = torch.floor(prefix_sum)
dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
@@ -693,11 +689,8 @@
device = hidden.device
dtype = hidden.dtype
batch_size, len_time, hidden_size = hidden.size()
- frames = torch.zeros(batch_size, len_time, hidden_size,
- dtype=dtype, device=device)
- prefix_sum_hidden = torch.cumsum(
- alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
- )
+ frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
+ prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
frames = prefix_sum_hidden[fire_idxs]
shift_frames = torch.roll(frames, 1, dims=0)
@@ -709,21 +702,16 @@
shift_frames[shift_batch_idxs] = 0
remains = fires - torch.floor(fires)
- remain_frames = (
- remains[fire_idxs].unsqueeze(-1).tile((1,
- hidden_size)) * hidden[fire_idxs]
- )
+ remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
shift_remain_frames[shift_batch_idxs] = 0
frames = frames - shift_frames + shift_remain_frames - remain_frames
- max_label_len = torch.round(alphas.sum(-1)).int().max() # torch.round to calculate the max length
+ max_label_len = batch_len.max()
- frame_fires = torch.zeros(
- batch_size, max_label_len, hidden_size, dtype=dtype, device=device
- )
+ frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
frame_fires_idxs = indices < batch_len.unsqueeze(1)
frame_fires[frame_fires_idxs] = frames
--
Gitblit v1.9.1