From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365

---
 funasr/models/mfcca/encoder_layer_mfcca.py |   67 +++++++++++++++++----------------
 1 files changed, 35 insertions(+), 32 deletions(-)

diff --git a/funasr/models/mfcca/encoder_layer_mfcca.py b/funasr/models/mfcca/encoder_layer_mfcca.py
index cc86f89..f16a4e6 100644
--- a/funasr/models/mfcca/encoder_layer_mfcca.py
+++ b/funasr/models/mfcca/encoder_layer_mfcca.py
@@ -15,7 +15,6 @@
 from torch.autograd import Variable
 
 
-
 class Encoder_Conformer_Layer(nn.Module):
     """Encoder layer module.
 
@@ -111,7 +110,6 @@
         if self.normalize_before:
             x = self.norm_mha(x)
 
-
         if cache is None:
             x_q = x
         else:
@@ -120,12 +118,12 @@
             residual = residual[:, -1:, :]
             mask = None if mask is None else mask[:, -1:, :]
 
-        if self.cca_pos<2:
+        if self.cca_pos < 2:
             if pos_emb is not None:
                 x_att = self.self_attn(x_q, x, x, pos_emb, mask)
             else:
                 x_att = self.self_attn(x_q, x, x, mask)
-        else:    
+        else:
             x_att = self.self_attn(x_q, x, x, mask)
 
         if self.concat_after:
@@ -163,8 +161,6 @@
             return (x, pos_emb), mask
 
         return x, mask
-
-
 
 
 class EncoderLayer(nn.Module):
@@ -209,18 +205,18 @@
 
         self.encoder_cros_channel_atten = self_attn_cros_channel
         self.encoder_csa = Encoder_Conformer_Layer(
-                size,
-                self_attn_conformer,
-                feed_forward_csa,
-                feed_forward_macaron_csa,
-                conv_module_csa,
-                dropout_rate,
-                normalize_before,
-                concat_after,
-                cca_pos=0)
+            size,
+            self_attn_conformer,
+            feed_forward_csa,
+            feed_forward_macaron_csa,
+            conv_module_csa,
+            dropout_rate,
+            normalize_before,
+            concat_after,
+            cca_pos=0,
+        )
         self.norm_mha = LayerNorm(size)  # for the MHA module
         self.dropout = nn.Dropout(dropout_rate)
-
 
     def forward(self, x_input, mask, channel_size, cache=None):
         """Compute encoded features.
@@ -245,26 +241,33 @@
         x = self.norm_mha(x)
         t_leng = x.size(1)
         d_dim = x.size(2)
-        x_new = x.reshape(-1,channel_size,t_leng,d_dim).transpose(1,2) # x_new B*T * C * D
-        x_k_v = x_new.new(x_new.size(0),x_new.size(1),5,x_new.size(2),x_new.size(3))
-        pad_before = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
-        pad_after = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
-        x_pad = torch.cat([pad_before,x_new, pad_after], 1)
-        x_k_v[:,:,0,:,:]=x_pad[:,0:-4,:,:]
-        x_k_v[:,:,1,:,:]=x_pad[:,1:-3,:,:]
-        x_k_v[:,:,2,:,:]=x_pad[:,2:-2,:,:]
-        x_k_v[:,:,3,:,:]=x_pad[:,3:-1,:,:]
-        x_k_v[:,:,4,:,:]=x_pad[:,4:,:,:]
-        x_new = x_new.reshape(-1,channel_size,d_dim)
-        x_k_v = x_k_v.reshape(-1,5*channel_size,d_dim)
+        x_new = x.reshape(-1, channel_size, t_leng, d_dim).transpose(1, 2)  # x_new B*T * C * D
+        x_k_v = x_new.new(x_new.size(0), x_new.size(1), 5, x_new.size(2), x_new.size(3))
+        pad_before = Variable(torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))).type(
+            x_new.type()
+        )
+        pad_after = Variable(torch.zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))).type(
+            x_new.type()
+        )
+        x_pad = torch.cat([pad_before, x_new, pad_after], 1)
+        x_k_v[:, :, 0, :, :] = x_pad[:, 0:-4, :, :]
+        x_k_v[:, :, 1, :, :] = x_pad[:, 1:-3, :, :]
+        x_k_v[:, :, 2, :, :] = x_pad[:, 2:-2, :, :]
+        x_k_v[:, :, 3, :, :] = x_pad[:, 3:-1, :, :]
+        x_k_v[:, :, 4, :, :] = x_pad[:, 4:, :, :]
+        x_new = x_new.reshape(-1, channel_size, d_dim)
+        x_k_v = x_k_v.reshape(-1, 5 * channel_size, d_dim)
         x_att = self.encoder_cros_channel_atten(x_new, x_k_v, x_k_v, None)
-        x_att = x_att.reshape(-1,t_leng,channel_size,d_dim).transpose(1,2).reshape(-1,t_leng,d_dim)
+        x_att = (
+            x_att.reshape(-1, t_leng, channel_size, d_dim)
+            .transpose(1, 2)
+            .reshape(-1, t_leng, d_dim)
+        )
         x = residual + self.dropout(x_att)
         if pos_emb is not None:
-            x_input =  (x, pos_emb)
+            x_input = (x, pos_emb)
         else:
             x_input = x
         x_input, mask = self.encoder_csa(x_input, mask)
 
-
-        return x_input, mask , channel_size
+        return x_input, mask, channel_size

--
Gitblit v1.9.1