From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
---
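Note: the naming conflict referenced in the subject is local-variable shadowing, where one name is rebound to a second value inside the same function so the earlier value is lost. A minimal, hypothetical Python sketch of the pattern (illustrative names only, not the actual code in attentions.py):

    def forward(xs):
        res = sum(xs)              # outer result, intended return value
        for h in range(len(xs)):
            res = xs[h] * 2        # same name rebound inside the loop
        return res                 # returns the last loop value (8), not the sum (7)

    print(forward([1, 2, 4]))      # -> 8

Giving the inner binding its own name keeps both values addressable, which is the intent of the rename described in the subject.
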
funasr/models/language_model/rnn/attentions.py | 162 +++++++++++++-----------------------------------------
 1 file changed, 39 insertions(+), 123 deletions(-)
diff --git a/funasr/models/language_model/rnn/attentions.py b/funasr/models/language_model/rnn/attentions.py
index 30cdff9..f5c450c 100644
--- a/funasr/models/language_model/rnn/attentions.py
+++ b/funasr/models/language_model/rnn/attentions.py
@@ -10,9 +10,7 @@
from funasr.models.transformer.utils.nets_utils import to_device
-def _apply_attention_constraint(
- e, last_attended_idx, backward_window=1, forward_window=3
-):
+def _apply_attention_constraint(e, last_attended_idx, backward_window=1, forward_window=3):
"""Apply monotonic attention constraint.
This function apply the monotonic attention constraint
@@ -84,9 +82,7 @@
mask = 1.0 - make_pad_mask(enc_hs_len).float()
att_prev = mask / mask.new(enc_hs_len).unsqueeze(-1)
att_prev = att_prev.to(self.enc_h)
- self.c = torch.sum(
- self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1
- )
+ self.c = torch.sum(self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1)
return self.c, att_prev
@@ -150,8 +146,7 @@
dec_z = dec_z.view(batch, self.dunits)
e = torch.sum(
- self.pre_compute_enc_h
- * torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim),
+ self.pre_compute_enc_h * torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim),
dim=2,
) # utt x frame
@@ -262,9 +257,7 @@
and not store pre_compute_enc_h
"""
- def __init__(
- self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
- ):
+ def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
super(AttLoc, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
@@ -338,9 +331,7 @@
# initialize attention weight with uniform dist.
if att_prev is None:
# if no bias, 0 0-pad goes 0
- att_prev = 1.0 - make_pad_mask(enc_hs_len).to(
- device=dec_z.device, dtype=dec_z.dtype
- )
+ att_prev = 1.0 - make_pad_mask(enc_hs_len).to(device=dec_z.device, dtype=dec_z.dtype)
att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
# att_prev: utt x frame -> utt x 1 x 1 x frame
@@ -356,9 +347,7 @@
# dot with gvec
# utt x frame x att_dim -> utt x frame
- e = self.gvec(
- torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
- ).squeeze(2)
+ e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE: consider zero padding when compute w.
if self.mask is None:
@@ -367,9 +356,7 @@
# apply monotonic attention constraint (mainly for TTS)
if last_attended_idx is not None:
- e = _apply_attention_constraint(
- e, last_attended_idx, backward_window, forward_window
- )
+ e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
w = F.softmax(scaling * e, dim=1)
@@ -446,12 +433,8 @@
# initialize attention weight with uniform dist.
if att_prev_list is None:
# if no bias, 0 0-pad goes 0
- att_prev_list = to_device(
- enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())
- )
- att_prev_list = [
- att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1)
- ]
+ att_prev_list = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()))
+ att_prev_list = [att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1)]
# att_prev_list: L' * [B x T] => cov_vec B x T
cov_vec = sum(att_prev_list)
@@ -463,9 +446,7 @@
# dot with gvec
# utt x frame x att_dim -> utt x frame
- e = self.gvec(
- torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled)
- ).squeeze(2)
+ e = self.gvec(torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
@@ -499,9 +480,7 @@
flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
"""
- def __init__(
- self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False
- ):
+ def __init__(self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False):
super(AttLoc2D, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
@@ -580,9 +559,7 @@
# dot with gvec
# utt x frame x att_dim -> utt x frame
- e = self.gvec(
- torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
- ).squeeze(2)
+ e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
@@ -619,9 +596,7 @@
flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
"""
- def __init__(
- self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
- ):
+ def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
super(AttLocRec, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
@@ -740,9 +715,7 @@
flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
"""
- def __init__(
- self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
- ):
+ def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
super(AttCovLoc, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
@@ -804,9 +777,7 @@
if att_prev_list is None:
# if no bias, 0 0-pad goes 0
mask = 1.0 - make_pad_mask(enc_hs_len).float()
- att_prev_list = [
- to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
- ]
+ att_prev_list = [to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))]
# att_prev_list: L' * [B x T] => cov_vec B x T
cov_vec = sum(att_prev_list)
@@ -823,9 +794,7 @@
# dot with gvec
# utt x frame x att_dim -> utt x frame
- e = self.gvec(
- torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
- ).squeeze(2)
+ e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
@@ -908,17 +877,14 @@
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_k = [
- torch.tanh(self.mlp_k[h](self.enc_h))
- for h in six.moves.range(self.aheads)
+ torch.tanh(self.mlp_k[h](self.enc_h)) for h in six.moves.range(self.aheads)
]
if self.pre_compute_v is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
- self.pre_compute_v = [
- self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
- ]
+ self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
@@ -943,11 +909,7 @@
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
- c += [
- torch.sum(
- self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
- )
- ]
+ c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
# concat all of c
c = self.mlp_o(torch.cat(c, dim=1))
@@ -1024,17 +986,13 @@
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
- self.pre_compute_k = [
- self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)
- ]
+ self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
if self.pre_compute_v is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
- self.pre_compute_v = [
- self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
- ]
+ self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
@@ -1046,8 +1004,7 @@
for h in six.moves.range(self.aheads):
e = self.gvec[h](
torch.tanh(
- self.pre_compute_k[h]
- + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
+ self.pre_compute_k[h] + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
)
).squeeze(2)
@@ -1060,11 +1017,7 @@
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
- c += [
- torch.sum(
- self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
- )
- ]
+ c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
# concat all of c
c = self.mlp_o(torch.cat(c, dim=1))
@@ -1167,17 +1120,13 @@
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
- self.pre_compute_k = [
- self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)
- ]
+ self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
if self.pre_compute_v is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
- self.pre_compute_v = [
- self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
- ]
+ self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
@@ -1189,9 +1138,7 @@
for _ in six.moves.range(self.aheads):
# if no bias, 0 0-pad goes 0
mask = 1.0 - make_pad_mask(enc_hs_len).float()
- att_prev += [
- to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
- ]
+ att_prev += [to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))]
c = []
w = []
@@ -1217,11 +1164,7 @@
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
- c += [
- torch.sum(
- self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
- )
- ]
+ c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
# concat all of c
c = self.mlp_o(torch.cat(c, dim=1))
@@ -1323,17 +1266,13 @@
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
- self.pre_compute_k = [
- self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)
- ]
+ self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
if self.pre_compute_v is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
- self.pre_compute_v = [
- self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)
- ]
+ self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
@@ -1345,9 +1284,7 @@
for _ in six.moves.range(self.aheads):
# if no bias, 0 0-pad goes 0
mask = 1.0 - make_pad_mask(enc_hs_len).float()
- att_prev += [
- to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
- ]
+ att_prev += [to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))]
c = []
w = []
@@ -1373,11 +1310,7 @@
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
- c += [
- torch.sum(
- self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
- )
- ]
+ c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
# concat all of c
c = self.mlp_o(torch.cat(c, dim=1))
@@ -1484,9 +1417,7 @@
# dot with gvec
# utt x frame x att_dim -> utt x frame
- e = self.gvec(
- torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv)
- ).squeeze(2)
+ e = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv)).squeeze(2)
# NOTE: consider zero padding when compute w.
if self.mask is None:
@@ -1495,9 +1426,7 @@
# apply monotonic attention constraint (mainly for TTS)
if last_attended_idx is not None:
- e = _apply_attention_constraint(
- e, last_attended_idx, backward_window, forward_window
- )
+ e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
w = F.softmax(scaling * e, dim=1)
@@ -1619,9 +1548,7 @@
# dot with gvec
# utt x frame x att_dim -> utt x frame
- e = self.gvec(
- torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
- ).squeeze(2)
+ e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
@@ -1630,18 +1557,13 @@
# apply monotonic attention constraint (mainly for TTS)
if last_attended_idx is not None:
- e = _apply_attention_constraint(
- e, last_attended_idx, backward_window, forward_window
- )
+ e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
w = F.softmax(scaling * e, dim=1)
# forward attention
att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
- w = (
- self.trans_agent_prob * att_prev
- + (1 - self.trans_agent_prob) * att_prev_shift
- ) * w
+ w = (self.trans_agent_prob * att_prev + (1 - self.trans_agent_prob) * att_prev_shift) * w
# NOTE: clamp is needed to avoid nan gradient
w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)
@@ -1651,9 +1573,7 @@
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
# update transition agent prob
- self.trans_agent_prob = torch.sigmoid(
- self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1))
- )
+ self.trans_agent_prob = torch.sigmoid(self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1)))
return c, w
@@ -1717,9 +1637,7 @@
)
att_list.append(att)
else:
- raise ValueError(
- "Number of encoders needs to be more than one. {}".format(num_encs)
- )
+ raise ValueError("Number of encoders needs to be more than one. {}".format(num_encs))
return att_list
@@ -1785,9 +1703,7 @@
att_ws = torch.stack([aw[:, -1] for aw in att_ws], dim=1).cpu().numpy()
elif isinstance(att, (AttCov, AttCovLoc)):
# att_ws => list of list of previous attentions
- att_ws = (
- torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy()
- )
+ att_ws = torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy()
elif isinstance(att, AttLocRec):
# att_ws => list of tuple of attention and hidden states
att_ws = torch.stack([aw[0] for aw in att_ws], dim=1).cpu().numpy()
--
Gitblit v1.9.1