From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' at line 514 to avoid a naming conflict with line 365
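
For context (not part of the diff below): the block reformatted in
EncoderLayer.forward builds the keys/values for the cross-channel attention
from a 5-frame window (t-2 .. t+2) around each frame, zero-padded at the
sequence edges. A minimal standalone sketch of that construction, assuming an
input of shape (B, T, C, D); the function name is hypothetical:

    import torch

    def build_cross_channel_kv(x_new: torch.Tensor) -> torch.Tensor:
        """x_new: (B, T, C, D) -> (B*T, 5*C, D) keys/values."""
        b, t, c, d = x_new.shape
        pad = x_new.new_zeros(b, 2, c, d)            # two zero frames per side
        x_pad = torch.cat([pad, x_new, pad], dim=1)  # (B, T+4, C, D)
        # slot k of the window holds frame t - 2 + k
        x_k_v = torch.stack([x_pad[:, k:k + t] for k in range(5)], dim=2)
        return x_k_v.reshape(-1, 5 * c, d)           # (B*T, 5*C, D)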
---
funasr/models/mfcca/encoder_layer_mfcca.py | 67 +++++++++++++++++----------------
 1 file changed, 35 insertions(+), 32 deletions(-)
diff --git a/funasr/models/mfcca/encoder_layer_mfcca.py b/funasr/models/mfcca/encoder_layer_mfcca.py
index cc86f89..f16a4e6 100644
--- a/funasr/models/mfcca/encoder_layer_mfcca.py
+++ b/funasr/models/mfcca/encoder_layer_mfcca.py
@@ -15,7 +15,6 @@
from torch.autograd import Variable
-
class Encoder_Conformer_Layer(nn.Module):
"""Encoder layer module.
@@ -111,7 +110,6 @@
if self.normalize_before:
x = self.norm_mha(x)
-
if cache is None:
x_q = x
else:
@@ -120,12 +118,12 @@
residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :]
- if self.cca_pos<2:
+ if self.cca_pos < 2:
if pos_emb is not None:
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
else:
x_att = self.self_attn(x_q, x, x, mask)
- else:
+ else:
x_att = self.self_attn(x_q, x, x, mask)
if self.concat_after:
@@ -163,8 +161,6 @@
return (x, pos_emb), mask
return x, mask
-
-
class EncoderLayer(nn.Module):
@@ -209,18 +205,18 @@
self.encoder_cros_channel_atten = self_attn_cros_channel
self.encoder_csa = Encoder_Conformer_Layer(
- size,
- self_attn_conformer,
- feed_forward_csa,
- feed_forward_macaron_csa,
- conv_module_csa,
- dropout_rate,
- normalize_before,
- concat_after,
- cca_pos=0)
+ size,
+ self_attn_conformer,
+ feed_forward_csa,
+ feed_forward_macaron_csa,
+ conv_module_csa,
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ cca_pos=0,
+ )
self.norm_mha = LayerNorm(size) # for the MHA module
self.dropout = nn.Dropout(dropout_rate)
-
def forward(self, x_input, mask, channel_size, cache=None):
"""Compute encoded features.
@@ -245,26 +241,33 @@
x = self.norm_mha(x)
t_leng = x.size(1)
d_dim = x.size(2)
- x_new = x.reshape(-1,channel_size,t_leng,d_dim).transpose(1,2) # x_new B*T * C * D
- x_k_v = x_new.new(x_new.size(0),x_new.size(1),5,x_new.size(2),x_new.size(3))
- pad_before = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
- pad_after = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
- x_pad = torch.cat([pad_before,x_new, pad_after], 1)
- x_k_v[:,:,0,:,:]=x_pad[:,0:-4,:,:]
- x_k_v[:,:,1,:,:]=x_pad[:,1:-3,:,:]
- x_k_v[:,:,2,:,:]=x_pad[:,2:-2,:,:]
- x_k_v[:,:,3,:,:]=x_pad[:,3:-1,:,:]
- x_k_v[:,:,4,:,:]=x_pad[:,4:,:,:]
- x_new = x_new.reshape(-1,channel_size,d_dim)
- x_k_v = x_k_v.reshape(-1,5*channel_size,d_dim)
+ x_new = x.reshape(-1, channel_size, t_leng, d_dim).transpose(1, 2) # x_new B*T * C * D
+ x_k_v = x_new.new(x_new.size(0), x_new.size(1), 5, x_new.size(2), x_new.size(3))
+ pad_before = Variable(torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))).type(
+ x_new.type()
+ )
+ pad_after = Variable(torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))).type(
+ x_new.type()
+ )
+ x_pad = torch.cat([pad_before, x_new, pad_after], 1)
+ x_k_v[:, :, 0, :, :] = x_pad[:, 0:-4, :, :]
+ x_k_v[:, :, 1, :, :] = x_pad[:, 1:-3, :, :]
+ x_k_v[:, :, 2, :, :] = x_pad[:, 2:-2, :, :]
+ x_k_v[:, :, 3, :, :] = x_pad[:, 3:-1, :, :]
+ x_k_v[:, :, 4, :, :] = x_pad[:, 4:, :, :]
+ x_new = x_new.reshape(-1, channel_size, d_dim)
+ x_k_v = x_k_v.reshape(-1, 5 * channel_size, d_dim)
x_att = self.encoder_cros_channel_atten(x_new, x_k_v, x_k_v, None)
- x_att = x_att.reshape(-1,t_leng,channel_size,d_dim).transpose(1,2).reshape(-1,t_leng,d_dim)
+ x_att = (
+ x_att.reshape(-1, t_leng, channel_size, d_dim)
+ .transpose(1, 2)
+ .reshape(-1, t_leng, d_dim)
+ )
x = residual + self.dropout(x_att)
if pos_emb is not None:
- x_input = (x, pos_emb)
+ x_input = (x, pos_emb)
else:
x_input = x
x_input, mask = self.encoder_csa(x_input, mask)
-
- return x_input, mask , channel_size
+ return x_input, mask, channel_size
--
Gitblit v1.9.1