| | |
| | | from funasr.models.transformer.utils.nets_utils import to_device |
| | | |
| | | |
| | | def _apply_attention_constraint( |
| | | e, last_attended_idx, backward_window=1, forward_window=3 |
| | | ): |
| | | def _apply_attention_constraint(e, last_attended_idx, backward_window=1, forward_window=3): |
| | | """Apply monotonic attention constraint. |
| | | |
| | | This function applies the monotonic attention constraint |
| | |
| | | mask = 1.0 - make_pad_mask(enc_hs_len).float() |
| | | att_prev = mask / mask.new(enc_hs_len).unsqueeze(-1) |
| | | att_prev = att_prev.to(self.enc_h) |
| | | self.c = torch.sum( |
| | | self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1 |
| | | ) |
| | | self.c = torch.sum(self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1) |
| | | |
| | | return self.c, att_prev |
| | | |
| | |
| | | dec_z = dec_z.view(batch, self.dunits) |
| | | |
| | | e = torch.sum( |
| | | self.pre_compute_enc_h |
| | | * torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim), |
| | | self.pre_compute_enc_h * torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim), |
| | | dim=2, |
| | | ) # utt x frame |
| | | |
| | |
| | | and not store pre_compute_enc_h |
| | | """ |
| | | |
| | | def __init__( |
| | | self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False |
| | | ): |
| | | def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False): |
| | | super(AttLoc, self).__init__() |
| | | self.mlp_enc = torch.nn.Linear(eprojs, att_dim) |
| | | self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) |
| | |
| | | # initialize attention weight with uniform dist. |
| | | if att_prev is None: |
| | | # if no bias, 0 0-pad goes 0 |
| | | att_prev = 1.0 - make_pad_mask(enc_hs_len).to( |
| | | device=dec_z.device, dtype=dec_z.dtype |
| | | ) |
| | | att_prev = 1.0 - make_pad_mask(enc_hs_len).to(device=dec_z.device, dtype=dec_z.dtype) |
| | | att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1) |
| | | |
| | | # att_prev: utt x frame -> utt x 1 x 1 x frame |
| | |
| | | |
| | | # dot with gvec |
| | | # utt x frame x att_dim -> utt x frame |
| | | e = self.gvec( |
| | | torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) |
| | | ).squeeze(2) |
| | | e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2) |
| | | |
| | | # NOTE: consider zero padding when compute w. |
| | | if self.mask is None: |
| | |
| | | |
| | | # apply monotonic attention constraint (mainly for TTS) |
| | | if last_attended_idx is not None: |
| | | e = _apply_attention_constraint( |
| | | e, last_attended_idx, backward_window, forward_window |
| | | ) |
| | | e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window) |
| | | |
| | | w = F.softmax(scaling * e, dim=1) |
| | | |
| | |
| | | # initialize attention weight with uniform dist. |
| | | if att_prev_list is None: |
| | | # if no bias, 0 0-pad goes 0 |
| | | att_prev_list = to_device( |
| | | enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()) |
| | | ) |
| | | att_prev_list = [ |
| | | att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1) |
| | | ] |
| | | att_prev_list = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())) |
| | | att_prev_list = [att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1)] |
| | | |
| | | # att_prev_list: L' * [B x T] => cov_vec B x T |
| | | cov_vec = sum(att_prev_list) |
| | |
| | | |
| | | # dot with gvec |
| | | # utt x frame x att_dim -> utt x frame |
| | | e = self.gvec( |
| | | torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled) |
| | | ).squeeze(2) |
| | | e = self.gvec(torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2) |
| | | |
| | | # NOTE consider zero padding when compute w. |
| | | if self.mask is None: |
| | |
| | | flag to switch on mode of hierarchical attention and not store pre_compute_enc_h |
| | | """ |
| | | |
| | | def __init__( |
| | | self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False |
| | | ): |
| | | def __init__(self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False): |
| | | super(AttLoc2D, self).__init__() |
| | | self.mlp_enc = torch.nn.Linear(eprojs, att_dim) |
| | | self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) |
| | |
| | | |
| | | # dot with gvec |
| | | # utt x frame x att_dim -> utt x frame |
| | | e = self.gvec( |
| | | torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) |
| | | ).squeeze(2) |
| | | e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2) |
| | | |
| | | # NOTE consider zero padding when compute w. |
| | | if self.mask is None: |
| | |
| | | flag to switch on mode of hierarchical attention and not store pre_compute_enc_h |
| | | """ |
| | | |
| | | def __init__( |
| | | self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False |
| | | ): |
| | | def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False): |
| | | super(AttLocRec, self).__init__() |
| | | self.mlp_enc = torch.nn.Linear(eprojs, att_dim) |
| | | self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) |
| | |
| | | flag to switch on mode of hierarchical attention and not store pre_compute_enc_h |
| | | """ |
| | | |
| | | def __init__( |
| | | self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False |
| | | ): |
| | | def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False): |
| | | super(AttCovLoc, self).__init__() |
| | | self.mlp_enc = torch.nn.Linear(eprojs, att_dim) |
| | | self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) |
| | |
| | | if att_prev_list is None: |
| | | # if no bias, 0 0-pad goes 0 |
| | | mask = 1.0 - make_pad_mask(enc_hs_len).float() |
| | | att_prev_list = [ |
| | | to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) |
| | | ] |
| | | att_prev_list = [to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))] |
| | | |
| | | # att_prev_list: L' * [B x T] => cov_vec B x T |
| | | cov_vec = sum(att_prev_list) |
| | |
| | | |
| | | # dot with gvec |
| | | # utt x frame x att_dim -> utt x frame |
| | | e = self.gvec( |
| | | torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) |
| | | ).squeeze(2) |
| | | e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2) |
| | | |
| | | # NOTE consider zero padding when compute w. |
| | | if self.mask is None: |
| | |
| | | self.h_length = self.enc_h.size(1) |
| | | # utt x frame x att_dim |
| | | self.pre_compute_k = [ |
| | | torch.tanh(self.mlp_k[h](self.enc_h)) |
| | | for h in six.moves.range(self.aheads) |
| | | torch.tanh(self.mlp_k[h](self.enc_h)) for h in six.moves.range(self.aheads) |
| | | ] |
| | | |
| | | if self.pre_compute_v is None or self.han_mode: |
| | | self.enc_h = enc_hs_pad # utt x frame x hdim |
| | | self.h_length = self.enc_h.size(1) |
| | | # utt x frame x att_dim |
| | | self.pre_compute_v = [ |
| | | self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) |
| | | ] |
| | | self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)] |
| | | |
| | | if dec_z is None: |
| | | dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
| | |
| | | # weighted sum over frames |
| | | # utt x hdim |
| | | # NOTE use bmm instead of sum(*) |
| | | c += [ |
| | | torch.sum( |
| | | self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 |
| | | ) |
| | | ] |
| | | c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)] |
| | | |
| | | # concat all of c |
| | | c = self.mlp_o(torch.cat(c, dim=1)) |
| | |
| | | self.enc_h = enc_hs_pad # utt x frame x hdim |
| | | self.h_length = self.enc_h.size(1) |
| | | # utt x frame x att_dim |
| | | self.pre_compute_k = [ |
| | | self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads) |
| | | ] |
| | | self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)] |
| | | |
| | | if self.pre_compute_v is None or self.han_mode: |
| | | self.enc_h = enc_hs_pad # utt x frame x hdim |
| | | self.h_length = self.enc_h.size(1) |
| | | # utt x frame x att_dim |
| | | self.pre_compute_v = [ |
| | | self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) |
| | | ] |
| | | self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)] |
| | | |
| | | if dec_z is None: |
| | | dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
| | |
| | | for h in six.moves.range(self.aheads): |
| | | e = self.gvec[h]( |
| | | torch.tanh( |
| | | self.pre_compute_k[h] |
| | | + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k) |
| | | self.pre_compute_k[h] + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k) |
| | | ) |
| | | ).squeeze(2) |
| | | |
| | |
| | | # weighted sum over frames |
| | | # utt x hdim |
| | | # NOTE use bmm instead of sum(*) |
| | | c += [ |
| | | torch.sum( |
| | | self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 |
| | | ) |
| | | ] |
| | | c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)] |
| | | |
| | | # concat all of c |
| | | c = self.mlp_o(torch.cat(c, dim=1)) |
| | |
| | | self.enc_h = enc_hs_pad # utt x frame x hdim |
| | | self.h_length = self.enc_h.size(1) |
| | | # utt x frame x att_dim |
| | | self.pre_compute_k = [ |
| | | self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads) |
| | | ] |
| | | self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)] |
| | | |
| | | if self.pre_compute_v is None or self.han_mode: |
| | | self.enc_h = enc_hs_pad # utt x frame x hdim |
| | | self.h_length = self.enc_h.size(1) |
| | | # utt x frame x att_dim |
| | | self.pre_compute_v = [ |
| | | self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) |
| | | ] |
| | | self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)] |
| | | |
| | | if dec_z is None: |
| | | dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
| | |
| | | for _ in six.moves.range(self.aheads): |
| | | # if no bias, 0 0-pad goes 0 |
| | | mask = 1.0 - make_pad_mask(enc_hs_len).float() |
| | | att_prev += [ |
| | | to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) |
| | | ] |
| | | att_prev += [to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))] |
| | | |
| | | c = [] |
| | | w = [] |
| | |
| | | # weighted sum over frames |
| | | # utt x hdim |
| | | # NOTE use bmm instead of sum(*) |
| | | c += [ |
| | | torch.sum( |
| | | self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 |
| | | ) |
| | | ] |
| | | c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)] |
| | | |
| | | # concat all of c |
| | | c = self.mlp_o(torch.cat(c, dim=1)) |
| | |
| | | self.enc_h = enc_hs_pad # utt x frame x hdim |
| | | self.h_length = self.enc_h.size(1) |
| | | # utt x frame x att_dim |
| | | self.pre_compute_k = [ |
| | | self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads) |
| | | ] |
| | | self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)] |
| | | |
| | | if self.pre_compute_v is None or self.han_mode: |
| | | self.enc_h = enc_hs_pad # utt x frame x hdim |
| | | self.h_length = self.enc_h.size(1) |
| | | # utt x frame x att_dim |
| | | self.pre_compute_v = [ |
| | | self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) |
| | | ] |
| | | self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)] |
| | | |
| | | if dec_z is None: |
| | | dec_z = enc_hs_pad.new_zeros(batch, self.dunits) |
| | |
| | | for _ in six.moves.range(self.aheads): |
| | | # if no bias, 0 0-pad goes 0 |
| | | mask = 1.0 - make_pad_mask(enc_hs_len).float() |
| | | att_prev += [ |
| | | to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) |
| | | ] |
| | | att_prev += [to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))] |
| | | |
| | | c = [] |
| | | w = [] |
| | |
| | | # weighted sum over frames |
| | | # utt x hdim |
| | | # NOTE use bmm instead of sum(*) |
| | | c += [ |
| | | torch.sum( |
| | | self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 |
| | | ) |
| | | ] |
| | | c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)] |
| | | |
| | | # concat all of c |
| | | c = self.mlp_o(torch.cat(c, dim=1)) |
| | |
| | | |
| | | # dot with gvec |
| | | # utt x frame x att_dim -> utt x frame |
| | | e = self.gvec( |
| | | torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv) |
| | | ).squeeze(2) |
| | | e = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv)).squeeze(2) |
| | | |
| | | # NOTE: consider zero padding when compute w. |
| | | if self.mask is None: |
| | |
| | | |
| | | # apply monotonic attention constraint (mainly for TTS) |
| | | if last_attended_idx is not None: |
| | | e = _apply_attention_constraint( |
| | | e, last_attended_idx, backward_window, forward_window |
| | | ) |
| | | e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window) |
| | | |
| | | w = F.softmax(scaling * e, dim=1) |
| | | |
| | |
| | | |
| | | # dot with gvec |
| | | # utt x frame x att_dim -> utt x frame |
| | | e = self.gvec( |
| | | torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) |
| | | ).squeeze(2) |
| | | e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2) |
| | | |
| | | # NOTE consider zero padding when compute w. |
| | | if self.mask is None: |
| | |
| | | |
| | | # apply monotonic attention constraint (mainly for TTS) |
| | | if last_attended_idx is not None: |
| | | e = _apply_attention_constraint( |
| | | e, last_attended_idx, backward_window, forward_window |
| | | ) |
| | | e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window) |
| | | |
| | | w = F.softmax(scaling * e, dim=1) |
| | | |
| | | # forward attention |
| | | att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1] |
| | | w = ( |
| | | self.trans_agent_prob * att_prev |
| | | + (1 - self.trans_agent_prob) * att_prev_shift |
| | | ) * w |
| | | w = (self.trans_agent_prob * att_prev + (1 - self.trans_agent_prob) * att_prev_shift) * w |
| | | # NOTE: clamp is needed to avoid nan gradient |
| | | w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1) |
| | | |
| | |
| | | c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) |
| | | |
| | | # update transition agent prob |
| | | self.trans_agent_prob = torch.sigmoid( |
| | | self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1)) |
| | | ) |
| | | self.trans_agent_prob = torch.sigmoid(self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1))) |
| | | |
| | | return c, w |
| | | |
| | |
| | | ) |
| | | att_list.append(att) |
| | | else: |
| | | raise ValueError( |
| | | "Number of encoders needs to be more than one. {}".format(num_encs) |
| | | ) |
| | | raise ValueError("Number of encoders needs to be more than one. {}".format(num_encs)) |
| | | return att_list |
| | | |
| | | |
| | |
| | | att_ws = torch.stack([aw[:, -1] for aw in att_ws], dim=1).cpu().numpy() |
| | | elif isinstance(att, (AttCov, AttCovLoc)): |
| | | # att_ws => list of list of previous attentions |
| | | att_ws = ( |
| | | torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy() |
| | | ) |
| | | att_ws = torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy() |
| | | elif isinstance(att, AttLocRec): |
| | | # att_ws => list of tuple of attention and hidden states |
| | | att_ws = torch.stack([aw[0] for aw in att_ws], dim=1).cpu().numpy() |