import logging
from typing import Optional

import numpy as np
import torch
from torch import nn


def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    """Build a 0/1 validity mask from sequence lengths.

    Entry ``[..., j]`` is 1 where ``j < lengths[...]`` and 0 otherwise.

    Args:
        lengths: Integer tensor of sequence lengths.
        maxlen: Size of the trailing (time) axis; defaults to ``lengths.max()``.
        dtype: dtype of the returned mask.
        device: If given, the mask is moved there; otherwise it stays on
            ``lengths.device``.

    Returns:
        Tensor of shape ``lengths.shape + (maxlen,)``.
    """
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = (row_vector < matrix).detach()
    mask = mask.type(dtype)
    return mask.to(device) if device is not None else mask


def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Build a boolean mask that is True at *padded* positions.

    Args:
        lengths: List or tensor of per-sequence lengths (batch axis first).
        xs: Optional reference tensor; when given, the mask is broadcast to
            ``xs``'s shape along axis 0 and ``length_dim`` and moved to
            ``xs.device``.
        length_dim: Axis of ``xs`` that holds the time dimension (not 0).
        maxlen: Explicit time length; only allowed when ``xs`` is None.

    Returns:
        Bool tensor, ``(batch, maxlen)`` or shaped like ``xs``.

    Raises:
        ValueError: If ``length_dim`` is 0.
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if maxlen is None:
        if xs is None:
            maxlen = int(max(lengths))
        else:
            maxlen = xs.size(length_dim)
    else:
        assert xs is None
        assert maxlen >= int(max(lengths))

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, , None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


class CifPredictorV2(nn.Module):
    """CIF predictor: per-frame firing weights followed by continuous integrate-and-fire.

    ``hidden`` is indexed as ``(batch, time, dim)`` (see ``hidden.size()``
    unpacking in ``tail_process_fn`` and ``cif``). ``mask`` is expected as
    ``(batch, 1, time)``, matching how the test helpers at the bottom of this
    file call the module.
    """

    def __init__(self,
                 idim: int,
                 l_order: int,
                 r_order: int,
                 threshold: float = 1.0,
                 dropout: float = 0.1,
                 smooth_factor: float = 1.0,
                 noise_threshold: float = 0,
                 tail_threshold: float = 0.0,
                 ):
        super(CifPredictorV2, self).__init__()

        # Pad left/right so the conv keeps the time length unchanged.
        self.pad = nn.ConstantPad1d((l_order, r_order), 0.0)
        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
        self.cif_output = nn.Linear(idim, 1)
        # NOTE: constructed but not applied anywhere in this module.
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        # When > 0, forward() appends one extra frame whose weight is this value.
        self.tail_threshold = tail_threshold

    def forward(self, hidden: torch.Tensor, mask: torch.Tensor):
        """Predict alphas and integrate ``hidden`` into token embeddings.

        Args:
            hidden: ``(batch, time, idim)`` encoder features.
            mask: 0/1 validity mask, expected ``(batch, 1, time)``.

        Returns:
            Tuple ``(acoustic_embeds, token_num, alphas, cif_peak)``.
            ``token_num`` is the per-row sum of ``alphas``. It is floored when
            a tail frame is appended (``tail_threshold > 0``), in which case
            ``alphas``/``cif_peak`` have one extra time step.
        """
        context = hidden.transpose(1, 2)
        queries = self.pad(context)
        output = torch.relu(self.cif_conv1d(queries))
        output = self.cif_output(output.transpose(1, 2))
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        mask = mask.transpose(-1, -2).float()
        alphas = (alphas * mask).squeeze(-1)
        token_num = alphas.sum(-1)
        if self.tail_threshold > 0.0:
            hidden, alphas, token_num = self.tail_process_fn(
                hidden, alphas, mask=mask.squeeze(-1))
        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
        return acoustic_embeds, token_num, alphas, cif_peak

    def tail_process_fn(self,
                        hidden: torch.Tensor,
                        alphas: torch.Tensor,
                        token_num: Optional[torch.Tensor] = None,
                        mask: Optional[torch.Tensor] = None):
        """Append one trailing frame whose alpha is ``self.tail_threshold``.

        With a mask, the tail weight is placed at the first padded position of
        each row (for a left-aligned 0/1 mask), after growing the time axis by
        one so fully-valid rows get it at the new last position. Without a
        mask, the tail weight is placed in the new last column.

        Args:
            hidden: ``(batch, time, dim)`` features; a zero frame is appended.
            alphas: ``(batch, time)`` firing weights.
            token_num: Unused; kept for interface compatibility.
            mask: Optional ``(batch, time)`` 0/1 validity mask.

        Returns:
            ``(hidden, alphas, token_num_floor)`` with ``time + 1`` steps and
            ``token_num_floor = floor(alphas.sum(-1))``.
        """
        batch_size = hidden.size(0)
        hidden_size = hidden.size(2)
        if mask is not None:
            zeros_t = torch.zeros([batch_size, 1], dtype=mask.dtype, device=mask.device)
            ones_t = torch.ones_like(zeros_t)
            # [1, mask] - [mask, 0] is 1 exactly where the mask drops from 1 to 0.
            boundary = torch.cat([ones_t, mask], dim=1) - torch.cat([mask, zeros_t], dim=1)
            tail = boundary * self.tail_threshold
            # Grow alphas by one column (to match hidden) before adding the tail;
            # concatenating `tail` directly would yield 2 * time + 1 columns.
            pad = torch.zeros([batch_size, 1], dtype=alphas.dtype, device=alphas.device)
            alphas = torch.cat([alphas, pad], dim=1) + tail
        else:
            tail = torch.full([batch_size, 1], self.tail_threshold,
                              dtype=alphas.dtype, device=alphas.device)
            alphas = torch.cat([alphas, tail], dim=1)
        zeros = torch.zeros([batch_size, 1, hidden_size], dtype=hidden.dtype, device=hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num_floor = torch.floor(alphas.sum(dim=-1))
        return hidden, alphas, token_num_floor


@torch.jit.script
def cif(hidden, alphas, threshold: float):
    """Continuous integrate-and-fire over the time axis.

    Accumulates ``alphas`` frame by frame. Whenever the running sum reaches
    ``threshold`` a token "fires": the weighted sum of frames so far (the
    firing frame contributes only the part needed to complete one unit) is
    emitted, and the leftover weight starts the next token.

    Args:
        hidden: ``(batch, time, dim)`` features.
        alphas: ``(batch, time)`` non-negative firing weights.
        threshold: Firing threshold (1 is subtracted from the integrator on
            each fire).

    Returns:
        ``(embeds, fires)``. ``embeds`` is ``(batch, max_tokens, dim)``,
        zero-padded per row, where ``max_tokens`` is the larger of the maximum
        ``round(alphas.sum(-1))`` and the maximum fire count. ``fires`` is the
        ``(batch, time)`` integrator value at each frame (before reset).
    """
    batch_size, len_time, hidden_size = hidden.size()
    threshold_t = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)

    # loop vars
    integrate = torch.zeros([batch_size], device=hidden.device)
    frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
    ones = torch.ones([batch_size], device=hidden.device)
    # intermediate vars along time
    list_fires = []
    list_frames = []

    for t in range(len_time):
        alpha = alphas[:, t]
        distribution_completion = ones - integrate
        # In-place update is safe: `integrate` is rebound by torch.where below,
        # so tensors already appended to list_fires are never mutated again.
        integrate += alpha
        list_fires.append(integrate)
        fire_place = integrate >= threshold_t
        integrate = torch.where(fire_place, integrate - ones, integrate)
        # On a fire, only the weight needed to complete the token goes to it.
        cur = torch.where(fire_place, distribution_completion, alpha)
        remainds = alpha - cur

        frame += cur[:, None] * hidden[:, t, :]
        list_frames.append(frame)
        # After a fire, the next token starts from the leftover weight.
        frame = torch.where(fire_place[:, None], remainds[:, None] * hidden[:, t, :], frame)

    fires = torch.stack(list_fires, 1)
    frames = torch.stack(list_frames, 1)
    fire_mask = fires >= threshold_t

    len_labels = torch.round(alphas.sum(-1)).int()
    # Never pad to fewer rows than a sequence actually fired (would be a
    # negative pad size otherwise).
    max_label_len = max(int(len_labels.max()), int(fire_mask.sum(-1).max()))

    list_ls = []
    for b in range(batch_size):
        fire_idx = torch.nonzero(fire_mask[b]).squeeze(-1)
        selected = torch.index_select(frames[b], 0, fire_idx)
        pad_l = torch.zeros([max_label_len - selected.size(0), hidden_size],
                            device=hidden.device)
        list_ls.append(torch.cat([selected, pad_l], 0))
    return torch.stack(list_ls, 0), fires


def CifPredictorV2_test():
    """Script the predictor, round-trip it through 'test.pt', and run it."""
    x = torch.rand([2, 21, 2])
    x_len = torch.IntTensor([6, 21])
    mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
    x = x * mask[:, :, None]

    predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
    predictor_scripts.save('test.pt')
    loaded = torch.jit.load('test.pt')
    cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
    print(predictor_scripts.code)
    print(cif_output)


def CifPredictorV2_export_test():
    """Trace the predictor, save to 'test_trace.pt', reload, run on new shapes."""
    x = torch.rand([2, 21, 2])
    x_len = torch.IntTensor([6, 21])
    mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
    x = x * mask[:, :, None]

    predictor = CifPredictorV2(2, 1, 1)
    predictor_trace = torch.jit.trace(predictor, (x, mask[:, None, :]))
    predictor_trace.save('test_trace.pt')
    loaded = torch.jit.load('test_trace.pt')

    # Different batch size and length than the trace inputs.
    x = torch.rand([3, 30, 2])
    x_len = torch.IntTensor([6, 20, 30])
    mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
    x = x * mask[:, :, None]
    cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
    print(cif_output)


if __name__ == '__main__':
    CifPredictorV2_export_test()