liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/language_model/rnn/attentions.py
@@ -10,9 +10,7 @@
from funasr.models.transformer.utils.nets_utils import to_device
-def _apply_attention_constraint(
-    e, last_attended_idx, backward_window=1, forward_window=3
-):
+def _apply_attention_constraint(e, last_attended_idx, backward_window=1, forward_window=3):
    """Apply monotonic attention constraint.
    This function apply the monotonic attention constraint
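The one-line signature is purely cosmetic; the constraint itself masks attention energies outside a small window around the last attended frame before the softmax, so the alignment can only creep forward. A minimal sketch of that idea, independent of the exact FunASR implementation:

    import torch

    def monotonic_window_mask(e, last_attended_idx, backward_window=1, forward_window=3):
        # e: (1, n_frames) attention energies for a single utterance (TTS decoding runs with batch 1).
        # Frames outside [last - backward, last + forward) are set to -inf so the softmax
        # assigns them zero weight.
        lo = last_attended_idx - backward_window
        hi = last_attended_idx + forward_window
        if lo > 0:
            e[:, :lo] = -float("inf")
        if hi < e.size(1):
            e[:, hi:] = -float("inf")
        return e

    e = monotonic_window_mask(torch.zeros(1, 10), last_attended_idx=4)  # only frames 3-6 stay finite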
@@ -84,9 +82,7 @@
            mask = 1.0 - make_pad_mask(enc_hs_len).float()
            att_prev = mask / mask.new(enc_hs_len).unsqueeze(-1)
            att_prev = att_prev.to(self.enc_h)
-            self.c = torch.sum(
-                self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1
-            )
+            self.c = torch.sum(self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1)
        return self.c, att_prev
@@ -150,8 +146,7 @@
            dec_z = dec_z.view(batch, self.dunits)
        e = torch.sum(
-            self.pre_compute_enc_h
-            * torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim),
+            self.pre_compute_enc_h * torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim),
            dim=2,
        )  # utt x frame
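Joining the operands onto one line leaves the dot-product scoring unchanged: each frame's energy is the inner product, over att_dim, of the pre-computed encoder projection and the tanh of the projected decoder state, broadcast across frames. A hedged toy version of the same contraction (shapes and names here are illustrative only):

    import torch

    batch, frames, att_dim = 2, 5, 4
    pre_compute_enc_h = torch.randn(batch, frames, att_dim)  # stands in for tanh(mlp_enc(enc_h)), computed once
    dec_proj = torch.tanh(torch.randn(batch, att_dim))       # stands in for tanh(mlp_dec(dec_z))

    e = torch.sum(pre_compute_enc_h * dec_proj.view(batch, 1, att_dim), dim=2)  # utt x frame
    w = torch.softmax(e, dim=1)  # attention weights over frames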
@@ -262,9 +257,7 @@
        and not store pre_compute_enc_h
    """
-    def __init__(
-        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
-    ):
+    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
        super(AttLoc, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
@@ -338,9 +331,7 @@
        # initialize attention weight with uniform dist.
        if att_prev is None:
            # if no bias, 0 0-pad goes 0
-            att_prev = 1.0 - make_pad_mask(enc_hs_len).to(
-                device=dec_z.device, dtype=dec_z.dtype
-            )
+            att_prev = 1.0 - make_pad_mask(enc_hs_len).to(device=dec_z.device, dtype=dec_z.dtype)
            att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
        # att_prev: utt x frame -> utt x 1 x 1 x frame
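The collapsed .to() call keeps the same first-step behaviour: when no previous attention exists, att_prev is a uniform distribution over each utterance's real frames, with exact zeros on padding. A small self-contained illustration of that masking-plus-normalization (make_pad_mask is replicated inline here rather than imported):

    import torch

    enc_hs_len = torch.tensor([4, 2])  # true lengths in a batch of two, padded to 4 frames
    max_len = int(enc_hs_len.max())

    # equivalent of 1.0 - make_pad_mask(enc_hs_len): 1 for valid frames, 0 for padding
    valid = (torch.arange(max_len).unsqueeze(0) < enc_hs_len.unsqueeze(1)).float()
    att_prev = valid / enc_hs_len.unsqueeze(-1).float()  # each row sums to 1
    # tensor([[0.2500, 0.2500, 0.2500, 0.2500],
    #         [0.5000, 0.5000, 0.0000, 0.0000]])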
@@ -356,9 +347,7 @@
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
-        e = self.gvec(
-            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
-        ).squeeze(2)
+        e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
        # NOTE: consider zero padding when compute w.
        if self.mask is None:
@@ -367,9 +356,7 @@
        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
-            e = _apply_attention_constraint(
-                e, last_attended_idx, backward_window, forward_window
-            )
+            e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
        w = F.softmax(scaling * e, dim=1)
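The single-line gvec call is the usual additive, location-aware scoring: a learned vector reduces tanh(convolved location features + projected encoder states + tiled decoder projection) to one energy per frame, which is then scaled and softmax-normalized. A rough sketch with stand-in tensors (layer and variable names here are illustrative, not the class attributes):

    import torch
    import torch.nn.functional as F

    batch, frames, att_dim = 2, 6, 8
    gvec = torch.nn.Linear(att_dim, 1)

    att_conv = torch.randn(batch, frames, att_dim)           # projected location (convolution) features
    pre_compute_enc_h = torch.randn(batch, frames, att_dim)  # projected encoder states
    dec_z_tiled = torch.randn(batch, 1, att_dim)             # projected decoder state, broadcast over frames

    e = gvec(torch.tanh(att_conv + pre_compute_enc_h + dec_z_tiled)).squeeze(2)  # utt x frame
    w = F.softmax(2.0 * e, dim=1)  # scaling = 2.0 sharpens the distribution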
@@ -446,12 +433,8 @@
        # initialize attention weight with uniform dist.
        if att_prev_list is None:
            # if no bias, 0 0-pad goes 0
-            att_prev_list = to_device(
-                enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())
-            )
-            att_prev_list = [
-                att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1)
-            ]
+            att_prev_list = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()))
+            att_prev_list = [att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1)]
        # att_prev_list: L' * [B x T] => cov_vec B x T
        cov_vec = sum(att_prev_list)
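Only the wrapping changes here; the coverage mechanism is intact: every previous attention distribution is kept in att_prev_list and their sum forms the coverage vector, so frames that have already received mass are visible to the next scoring step (the sum is then projected, or convolved in the location variant, before entering the tanh). Schematically:

    import torch

    frames = 5
    att_prev_list = [torch.full((1, frames), 1.0 / frames)]  # step 0: uniform over valid frames

    # after each decoding step, append that step's attention weights
    att_prev_list.append(torch.tensor([[0.7, 0.2, 0.1, 0.0, 0.0]]))

    cov_vec = sum(att_prev_list)  # B x T: how much each frame has been attended so far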
@@ -463,9 +446,7 @@
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
-        e = self.gvec(
-            torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled)
-        ).squeeze(2)
+        e = self.gvec(torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
        # NOTE consider zero padding when compute w.
        if self.mask is None:
@@ -499,9 +480,7 @@
        flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
    """
-    def __init__(
-        self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False
-    ):
+    def __init__(self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False):
        super(AttLoc2D, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
@@ -580,9 +559,7 @@
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
-        e = self.gvec(
-            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
-        ).squeeze(2)
+        e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
        # NOTE consider zero padding when compute w.
        if self.mask is None:
@@ -619,9 +596,7 @@
        flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
    """
-    def __init__(
-        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
-    ):
+    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
        super(AttLocRec, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
@@ -740,9 +715,7 @@
        flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
    """
-    def __init__(
-        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
-    ):
+    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
        super(AttCovLoc, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
@@ -804,9 +777,7 @@
        if att_prev_list is None:
            # if no bias, 0 0-pad goes 0
            mask = 1.0 - make_pad_mask(enc_hs_len).float()
-            att_prev_list = [
-                to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
-            ]
+            att_prev_list = [to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))]
        # att_prev_list: L' * [B x T] => cov_vec B x T
        cov_vec = sum(att_prev_list)
@@ -823,9 +794,7 @@
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
-        e = self.gvec(
-            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
-        ).squeeze(2)
+        e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
        # NOTE consider zero padding when compute w.
        if self.mask is None:
@@ -908,17 +877,14 @@
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_k = [
-                torch.tanh(self.mlp_k[h](self.enc_h))
-                for h in six.moves.range(self.aheads)
+                torch.tanh(self.mlp_k[h](self.enc_h)) for h in six.moves.range(self.aheads)
            ]
        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
-            self.pre_compute_v = [
-                self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
-            ]
+            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
@@ -943,11 +909,7 @@
            # weighted sum over flames
            # utt x hdim
            # NOTE use bmm instead of sum(*)
-            c += [
-                torch.sum(
-                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
-                )
-            ]
+            c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))
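The collapsed torch.sum line is the per-head weighted sum of value projections, and the NOTE above it points out that the same contraction could be written with torch.bmm. A quick check of that equivalence with hypothetical shapes:

    import torch

    batch, frames, hdim = 2, 5, 8
    v = torch.randn(batch, frames, hdim)  # pre_compute_v[h] for one head
    w = torch.rand(batch, frames)
    w = w / w.sum(dim=1, keepdim=True)    # normalized attention weights w[h]

    c_sum = torch.sum(v * w.view(batch, frames, 1), dim=1)  # form used in the code
    c_bmm = torch.bmm(w.unsqueeze(1), v).squeeze(1)         # bmm form suggested by the NOTE

    assert torch.allclose(c_sum, c_bmm, atol=1e-6)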
@@ -1024,17 +986,13 @@
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
-            self.pre_compute_k = [
-                self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)
-            ]
+            self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
-            self.pre_compute_v = [
-                self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
-            ]
+            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
@@ -1046,8 +1004,7 @@
        for h in six.moves.range(self.aheads):
            e = self.gvec[h](
                torch.tanh(
-                    self.pre_compute_k[h]
-                    + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
+                    self.pre_compute_k[h] + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
                )
            ).squeeze(2)
@@ -1060,11 +1017,7 @@
            # weighted sum over flames
            # utt x hdim
            # NOTE use bmm instead of sum(*)
-            c += [
-                torch.sum(
-                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
-                )
-            ]
+            c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))
@@ -1167,17 +1120,13 @@
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
-            self.pre_compute_k = [
-                self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)
-            ]
+            self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
-            self.pre_compute_v = [
-                self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
-            ]
+            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
@@ -1189,9 +1138,7 @@
            for _ in six.moves.range(self.aheads):
                # if no bias, 0 0-pad goes 0
                mask = 1.0 - make_pad_mask(enc_hs_len).float()
-                att_prev += [
-                    to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
-                ]
+                att_prev += [to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))]
        c = []
        w = []
@@ -1217,11 +1164,7 @@
            # weighted sum over flames
            # utt x hdim
            # NOTE use bmm instead of sum(*)
-            c += [
-                torch.sum(
-                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
-                )
-            ]
+            c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))
@@ -1323,17 +1266,13 @@
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
-            self.pre_compute_k = [
-                self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)
-            ]
+            self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
-            self.pre_compute_v = [
-                self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
-            ]
+            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
@@ -1345,9 +1284,7 @@
            for _ in six.moves.range(self.aheads):
                # if no bias, 0 0-pad goes 0
                mask = 1.0 - make_pad_mask(enc_hs_len).float()
-                att_prev += [
-                    to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
-                ]
+                att_prev += [to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))]
        c = []
        w = []
@@ -1373,11 +1310,7 @@
            # weighted sum over flames
            # utt x hdim
            # NOTE use bmm instead of sum(*)
-            c += [
-                torch.sum(
-                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
-                )
-            ]
+            c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))
@@ -1484,9 +1417,7 @@
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
-        e = self.gvec(
-            torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv)
-        ).squeeze(2)
+        e = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv)).squeeze(2)
        # NOTE: consider zero padding when compute w.
        if self.mask is None:
@@ -1495,9 +1426,7 @@
        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
-            e = _apply_attention_constraint(
-                e, last_attended_idx, backward_window, forward_window
-            )
+            e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
        w = F.softmax(scaling * e, dim=1)
@@ -1619,9 +1548,7 @@
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
-        e = self.gvec(
-            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
-        ).squeeze(2)
+        e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
        # NOTE consider zero padding when compute w.
        if self.mask is None:
@@ -1630,18 +1557,13 @@
        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
-            e = _apply_attention_constraint(
-                e, last_attended_idx, backward_window, forward_window
-            )
+            e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
        w = F.softmax(scaling * e, dim=1)
        # forward attention
        att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
-        w = (
-            self.trans_agent_prob * att_prev
-            + (1 - self.trans_agent_prob) * att_prev_shift
-        ) * w
+        w = (self.trans_agent_prob * att_prev + (1 - self.trans_agent_prob) * att_prev_shift) * w
        # NOTE: clamp is needed to avoid nan gradient
        w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)
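Putting the transition-agent expression on one line does not change the forward-attention recursion: the previous alignment either stays on the same frame (with probability trans_agent_prob) or moves one frame ahead (its right-shifted copy), and that mixture gates the new softmax weights before clamping and renormalization. A compact sketch of the update:

    import torch
    import torch.nn.functional as F

    att_prev = torch.tensor([[0.6, 0.3, 0.1, 0.0]])  # alignment from the previous step (B x T)
    w = torch.tensor([[0.1, 0.5, 0.3, 0.1]])         # softmax(scaling * e) at the current step
    trans_agent_prob = 0.4                           # probability of staying on the same frame

    att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]  # move-forward case: shift right by one frame

    w = (trans_agent_prob * att_prev + (1 - trans_agent_prob) * att_prev_shift) * w
    w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)  # clamp avoids nan gradients, then L1-renormalize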
@@ -1651,9 +1573,7 @@
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        # update transition agent prob
-        self.trans_agent_prob = torch.sigmoid(
-            self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1))
-        )
+        self.trans_agent_prob = torch.sigmoid(self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1)))
        return c, w
@@ -1717,9 +1637,7 @@
                )
                att_list.append(att)
    else:
-        raise ValueError(
-            "Number of encoders needs to be more than one. {}".format(num_encs)
-        )
+        raise ValueError("Number of encoders needs to be more than one. {}".format(num_encs))
    return att_list
@@ -1785,9 +1703,7 @@
        att_ws = torch.stack([aw[:, -1] for aw in att_ws], dim=1).cpu().numpy()
    elif isinstance(att, (AttCov, AttCovLoc)):
        # att_ws => list of list of previous attentions
-        att_ws = (
-            torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy()
-        )
+        att_ws = torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy()
    elif isinstance(att, AttLocRec):
        # att_ws => list of tuple of attention and hidden states
        att_ws = torch.stack([aw[0] for aw in att_ws], dim=1).cpu().numpy()