From 8cc5bbf99a59694228aafcbe8712e09b9a4cb26b Mon Sep 17 00:00:00 2001 From: zhifu gao <zhifu.gzf@alibaba-inc.com> Date: 星期一, 27 二月 2023 17:01:48 +0800 Subject: [PATCH] Merge pull request #159 from alibaba-damo-academy/dev_dzh --- funasr/export/models/predictor/cif.py | 81 ++++++++++++++++++++++++++++++++++------ 1 files changed, 69 insertions(+), 12 deletions(-) diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py index 5518cb8..cb26862 100644 --- a/funasr/export/models/predictor/cif.py +++ b/funasr/export/models/predictor/cif.py @@ -16,6 +16,11 @@ return mask.type(dtype).to(device) if device is not None else mask.type(dtype) +def sequence_mask_scripts(lengths, maxlen:int): + row_vector = torch.arange(0, maxlen, 1).type(lengths.dtype).to(lengths.device) + matrix = torch.unsqueeze(lengths, dim=-1) + mask = row_vector < matrix + return mask.type(torch.float32).to(lengths.device) class CifPredictorV2(nn.Module): def __init__(self, model): @@ -71,28 +76,76 @@ return hidden, alphas, token_num_floor + +# @torch.jit.script +# def cif(hidden, alphas, threshold: float): +# batch_size, len_time, hidden_size = hidden.size() +# threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device) +# +# # loop varss +# integrate = torch.zeros([batch_size], device=hidden.device) +# frame = torch.zeros([batch_size, hidden_size], device=hidden.device) +# # intermediate vars along time +# list_fires = [] +# list_frames = [] +# +# for t in range(len_time): +# alpha = alphas[:, t] +# distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate +# +# integrate += alpha +# list_fires.append(integrate) +# +# fire_place = integrate >= threshold +# integrate = torch.where(fire_place, +# integrate - torch.ones([batch_size], device=hidden.device), +# integrate) +# cur = torch.where(fire_place, +# distribution_completion, +# alpha) +# remainds = alpha - cur +# +# frame += cur[:, None] * hidden[:, t, :] +# list_frames.append(frame) +# frame = torch.where(fire_place[:, None].repeat(1, hidden_size), +# remainds[:, None] * hidden[:, t, :], +# frame) +# +# fires = torch.stack(list_fires, 1) +# frames = torch.stack(list_frames, 1) +# list_ls = [] +# len_labels = torch.floor(alphas.sum(-1)).int() +# max_label_len = len_labels.max() +# for b in range(batch_size): +# fire = fires[b, :] +# l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze()) +# pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device) +# list_ls.append(torch.cat([l, pad_l], 0)) +# return torch.stack(list_ls, 0), fires + + @torch.jit.script def cif(hidden, alphas, threshold: float): batch_size, len_time, hidden_size = hidden.size() threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device) # loop varss - integrate = torch.zeros([batch_size], device=hidden.device) - frame = torch.zeros([batch_size, hidden_size], device=hidden.device) + integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device) + frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device) # intermediate vars along time list_fires = [] list_frames = [] for t in range(len_time): alpha = alphas[:, t] - distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate + distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate integrate += alpha list_fires.append(integrate) fire_place = integrate >= threshold integrate = torch.where(fire_place, - integrate - torch.ones([batch_size], device=hidden.device), + integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device), integrate) cur = torch.where(fire_place, distribution_completion, @@ -107,12 +160,16 @@ fires = torch.stack(list_fires, 1) frames = torch.stack(list_frames, 1) - list_ls = [] - len_labels = torch.round(alphas.sum(-1)).int() - max_label_len = len_labels.max() + + fire_idxs = fires >= threshold + frame_fires = torch.zeros_like(hidden) + max_label_len = frames[0, fire_idxs[0]].size(0) for b in range(batch_size): - fire = fires[b, :] - l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze()) - pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device) - list_ls.append(torch.cat([l, pad_l], 0)) - return torch.stack(list_ls, 0), fires + frame_fire = frames[b, fire_idxs[b]] + frame_len = frame_fire.size(0) + frame_fires[b, :frame_len, :] = frame_fire + + if frame_len >= max_label_len: + max_label_len = frame_len + frame_fires = frame_fires[:, :max_label_len, :] + return frame_fires, fires -- Gitblit v1.9.1