From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365

---
 funasr/models/language_model/rnn/attentions.py |  162 +++++++++++++-----------------------------------------
 1 file changed, 39 insertions(+), 123 deletions(-)

diff --git a/funasr/models/language_model/rnn/attentions.py b/funasr/models/language_model/rnn/attentions.py
index 30cdff9..f5c450c 100644
--- a/funasr/models/language_model/rnn/attentions.py
+++ b/funasr/models/language_model/rnn/attentions.py
@@ -10,9 +10,7 @@
 from funasr.models.transformer.utils.nets_utils import to_device
 
 
-def _apply_attention_constraint(
-    e, last_attended_idx, backward_window=1, forward_window=3
-):
+def _apply_attention_constraint(e, last_attended_idx, backward_window=1, forward_window=3):
     """Apply monotonic attention constraint.
 
     This function apply the monotonic attention constraint
@@ -84,9 +82,7 @@
             mask = 1.0 - make_pad_mask(enc_hs_len).float()
             att_prev = mask / mask.new(enc_hs_len).unsqueeze(-1)
             att_prev = att_prev.to(self.enc_h)
-            self.c = torch.sum(
-                self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1
-            )
+            self.c = torch.sum(self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1)
 
         return self.c, att_prev
 
@@ -150,8 +146,7 @@
             dec_z = dec_z.view(batch, self.dunits)
 
         e = torch.sum(
-            self.pre_compute_enc_h
-            * torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim),
+            self.pre_compute_enc_h * torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim),
             dim=2,
         )  # utt x frame
 
@@ -262,9 +257,7 @@
         and not store pre_compute_enc_h
     """
 
-    def __init__(
-        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
-    ):
+    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
         super(AttLoc, self).__init__()
         self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
         self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
@@ -338,9 +331,7 @@
         # initialize attention weight with uniform dist.
         if att_prev is None:
             # if no bias, 0 0-pad goes 0
-            att_prev = 1.0 - make_pad_mask(enc_hs_len).to(
-                device=dec_z.device, dtype=dec_z.dtype
-            )
+            att_prev = 1.0 - make_pad_mask(enc_hs_len).to(device=dec_z.device, dtype=dec_z.dtype)
             att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
 
         # att_prev: utt x frame -> utt x 1 x 1 x frame
@@ -356,9 +347,7 @@
 
         # dot with gvec
         # utt x frame x att_dim -> utt x frame
-        e = self.gvec(
-            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
-        ).squeeze(2)
+        e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
 
         # NOTE: consider zero padding when compute w.
         if self.mask is None:
@@ -367,9 +356,7 @@
 
         # apply monotonic attention constraint (mainly for TTS)
         if last_attended_idx is not None:
-            e = _apply_attention_constraint(
-                e, last_attended_idx, backward_window, forward_window
-            )
+            e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
 
         w = F.softmax(scaling * e, dim=1)
 
@@ -446,12 +433,8 @@
         # initialize attention weight with uniform dist.
         if att_prev_list is None:
             # if no bias, 0 0-pad goes 0
-            att_prev_list = to_device(
-                enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())
-            )
-            att_prev_list = [
-                att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1)
-            ]
+            att_prev_list = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()))
+            att_prev_list = [att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1)]
 
         # att_prev_list: L' * [B x T] => cov_vec B x T
         cov_vec = sum(att_prev_list)
@@ -463,9 +446,7 @@
 
         # dot with gvec
         # utt x frame x att_dim -> utt x frame
-        e = self.gvec(
-            torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled)
-        ).squeeze(2)
+        e = self.gvec(torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
 
         # NOTE consider zero padding when compute w.
         if self.mask is None:
@@ -499,9 +480,7 @@
         flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
     """
 
-    def __init__(
-        self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False
-    ):
+    def __init__(self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False):
         super(AttLoc2D, self).__init__()
         self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
         self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
@@ -580,9 +559,7 @@
 
         # dot with gvec
         # utt x frame x att_dim -> utt x frame
-        e = self.gvec(
-            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
-        ).squeeze(2)
+        e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
 
         # NOTE consider zero padding when compute w.
         if self.mask is None:
@@ -619,9 +596,7 @@
         flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
     """
 
-    def __init__(
-        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
-    ):
+    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
         super(AttLocRec, self).__init__()
         self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
         self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
@@ -740,9 +715,7 @@
         flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
     """
 
-    def __init__(
-        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
-    ):
+    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
         super(AttCovLoc, self).__init__()
         self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
         self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
@@ -804,9 +777,7 @@
         if att_prev_list is None:
             # if no bias, 0 0-pad goes 0
             mask = 1.0 - make_pad_mask(enc_hs_len).float()
-            att_prev_list = [
-                to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
-            ]
+            att_prev_list = [to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))]
 
         # att_prev_list: L' * [B x T] => cov_vec B x T
         cov_vec = sum(att_prev_list)
@@ -823,9 +794,7 @@
 
         # dot with gvec
         # utt x frame x att_dim -> utt x frame
-        e = self.gvec(
-            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
-        ).squeeze(2)
+        e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
 
         # NOTE consider zero padding when compute w.
         if self.mask is None:
@@ -908,17 +877,14 @@
             self.h_length = self.enc_h.size(1)
             # utt x frame x att_dim
             self.pre_compute_k = [
-                torch.tanh(self.mlp_k[h](self.enc_h))
-                for h in six.moves.range(self.aheads)
+                torch.tanh(self.mlp_k[h](self.enc_h)) for h in six.moves.range(self.aheads)
             ]
 
         if self.pre_compute_v is None or self.han_mode:
             self.enc_h = enc_hs_pad  # utt x frame x hdim
             self.h_length = self.enc_h.size(1)
             # utt x frame x att_dim
-            self.pre_compute_v = [
-                self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
-            ]
+            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
 
         if dec_z is None:
             dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
@@ -943,11 +909,7 @@
             # weighted sum over flames
             # utt x hdim
             # NOTE use bmm instead of sum(*)
-            c += [
-                torch.sum(
-                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
-                )
-            ]
+            c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
 
         # concat all of c
         c = self.mlp_o(torch.cat(c, dim=1))
@@ -1024,17 +986,13 @@
             self.enc_h = enc_hs_pad  # utt x frame x hdim
             self.h_length = self.enc_h.size(1)
             # utt x frame x att_dim
-            self.pre_compute_k = [
-                self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)
-            ]
+            self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
 
         if self.pre_compute_v is None or self.han_mode:
             self.enc_h = enc_hs_pad  # utt x frame x hdim
             self.h_length = self.enc_h.size(1)
             # utt x frame x att_dim
-            self.pre_compute_v = [
-                self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
-            ]
+            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
 
         if dec_z is None:
             dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
@@ -1046,8 +1004,7 @@
         for h in six.moves.range(self.aheads):
             e = self.gvec[h](
                 torch.tanh(
-                    self.pre_compute_k[h]
-                    + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
+                    self.pre_compute_k[h] + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
                 )
             ).squeeze(2)
 
@@ -1060,11 +1017,7 @@
             # weighted sum over flames
             # utt x hdim
             # NOTE use bmm instead of sum(*)
-            c += [
-                torch.sum(
-                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
-                )
-            ]
+            c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
 
         # concat all of c
         c = self.mlp_o(torch.cat(c, dim=1))
@@ -1167,17 +1120,13 @@
             self.enc_h = enc_hs_pad  # utt x frame x hdim
             self.h_length = self.enc_h.size(1)
             # utt x frame x att_dim
-            self.pre_compute_k = [
-                self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)
-            ]
+            self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
 
         if self.pre_compute_v is None or self.han_mode:
             self.enc_h = enc_hs_pad  # utt x frame x hdim
             self.h_length = self.enc_h.size(1)
             # utt x frame x att_dim
-            self.pre_compute_v = [
-                self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
-            ]
+            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
 
         if dec_z is None:
             dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
@@ -1189,9 +1138,7 @@
             for _ in six.moves.range(self.aheads):
                 # if no bias, 0 0-pad goes 0
                 mask = 1.0 - make_pad_mask(enc_hs_len).float()
-                att_prev += [
-                    to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
-                ]
+                att_prev += [to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))]
 
         c = []
         w = []
@@ -1217,11 +1164,7 @@
             # weighted sum over flames
             # utt x hdim
             # NOTE use bmm instead of sum(*)
-            c += [
-                torch.sum(
-                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
-                )
-            ]
+            c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
 
         # concat all of c
         c = self.mlp_o(torch.cat(c, dim=1))
@@ -1323,17 +1266,13 @@
             self.enc_h = enc_hs_pad  # utt x frame x hdim
             self.h_length = self.enc_h.size(1)
             # utt x frame x att_dim
-            self.pre_compute_k = [
-                self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)
-            ]
+            self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
 
         if self.pre_compute_v is None or self.han_mode:
             self.enc_h = enc_hs_pad  # utt x frame x hdim
             self.h_length = self.enc_h.size(1)
             # utt x frame x att_dim
-            self.pre_compute_v = [
-                self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
-            ]
+            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
 
         if dec_z is None:
             dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
@@ -1345,9 +1284,7 @@
             for _ in six.moves.range(self.aheads):
                 # if no bias, 0 0-pad goes 0
                 mask = 1.0 - make_pad_mask(enc_hs_len).float()
-                att_prev += [
-                    to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
-                ]
+                att_prev += [to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))]
 
         c = []
         w = []
@@ -1373,11 +1310,7 @@
             # weighted sum over flames
             # utt x hdim
             # NOTE use bmm instead of sum(*)
-            c += [
-                torch.sum(
-                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
-                )
-            ]
+            c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
 
         # concat all of c
         c = self.mlp_o(torch.cat(c, dim=1))
@@ -1484,9 +1417,7 @@
 
         # dot with gvec
         # utt x frame x att_dim -> utt x frame
-        e = self.gvec(
-            torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv)
-        ).squeeze(2)
+        e = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv)).squeeze(2)
 
         # NOTE: consider zero padding when compute w.
         if self.mask is None:
@@ -1495,9 +1426,7 @@
 
         # apply monotonic attention constraint (mainly for TTS)
         if last_attended_idx is not None:
-            e = _apply_attention_constraint(
-                e, last_attended_idx, backward_window, forward_window
-            )
+            e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
 
         w = F.softmax(scaling * e, dim=1)
 
@@ -1619,9 +1548,7 @@
 
         # dot with gvec
         # utt x frame x att_dim -> utt x frame
-        e = self.gvec(
-            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
-        ).squeeze(2)
+        e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
 
         # NOTE consider zero padding when compute w.
         if self.mask is None:
@@ -1630,18 +1557,13 @@
 
         # apply monotonic attention constraint (mainly for TTS)
         if last_attended_idx is not None:
-            e = _apply_attention_constraint(
-                e, last_attended_idx, backward_window, forward_window
-            )
+            e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
 
         w = F.softmax(scaling * e, dim=1)
 
         # forward attention
         att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
-        w = (
-            self.trans_agent_prob * att_prev
-            + (1 - self.trans_agent_prob) * att_prev_shift
-        ) * w
+        w = (self.trans_agent_prob * att_prev + (1 - self.trans_agent_prob) * att_prev_shift) * w
         # NOTE: clamp is needed to avoid nan gradient
         w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)
 
@@ -1651,9 +1573,7 @@
         c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
 
         # update transition agent prob
-        self.trans_agent_prob = torch.sigmoid(
-            self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1))
-        )
+        self.trans_agent_prob = torch.sigmoid(self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1)))
 
         return c, w
 
@@ -1717,9 +1637,7 @@
                 )
                 att_list.append(att)
     else:
-        raise ValueError(
-            "Number of encoders needs to be more than one. {}".format(num_encs)
-        )
+        raise ValueError("Number of encoders needs to be more than one. {}".format(num_encs))
     return att_list
 
 
@@ -1785,9 +1703,7 @@
         att_ws = torch.stack([aw[:, -1] for aw in att_ws], dim=1).cpu().numpy()
     elif isinstance(att, (AttCov, AttCovLoc)):
         # att_ws => list of list of previous attentions
-        att_ws = (
-            torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy()
-        )
+        att_ws = torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy()
     elif isinstance(att, AttLocRec):
         # att_ws => list of tuple of attention and hidden states
         att_ws = torch.stack([aw[0] for aw in att_ws], dim=1).cpu().numpy()

--
Gitblit v1.9.1