From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365

---
 funasr/models/transformer/utils/lightconv2d.py |   14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/funasr/models/transformer/utils/lightconv2d.py b/funasr/models/transformer/utils/lightconv2d.py
index 294d232..60e5af5 100644
--- a/funasr/models/transformer/utils/lightconv2d.py
+++ b/funasr/models/transformer/utils/lightconv2d.py
@@ -50,9 +50,7 @@
         self.act = nn.GLU()
 
         # lightconv related
-        self.weight = nn.Parameter(
-            torch.Tensor(self.wshare, 1, kernel_size).uniform_(0, 1)
-        )
+        self.weight = nn.Parameter(torch.Tensor(self.wshare, 1, kernel_size).uniform_(0, 1))
         self.weight_f = nn.Parameter(torch.Tensor(1, 1, kernel_size).uniform_(0, 1))
         self.use_bias = use_bias
         if self.use_bias:
@@ -93,9 +91,9 @@
         # convolution along frequency axis
         weight_f = F.softmax(self.weight_f, dim=-1)
         weight_f = F.dropout(weight_f, self.dropout_rate, training=self.training)
-        weight_new = torch.zeros(
-            B * T, 1, self.kernel_size, device=x.device, dtype=x.dtype
-        ).copy_(weight_f)
+        weight_new = torch.zeros(B * T, 1, self.kernel_size, device=x.device, dtype=x.dtype).copy_(
+            weight_f
+        )
         xf = F.conv1d(
             x.view(1, B * T, C), weight_new, padding=self.padding_size, groups=B * T
         ).view(B, T, C)
@@ -107,9 +105,7 @@
             self.kernel_mask = self.kernel_mask.to(x.device)
             weight = weight.masked_fill(self.kernel_mask == 0.0, float("-inf"))
         weight = F.softmax(weight, dim=-1)
-        x = F.conv1d(x, weight, padding=self.padding_size, groups=self.wshare).view(
-            B, C, T
-        )
+        x = F.conv1d(x, weight, padding=self.padding_size, groups=self.wshare).view(B, C, T)
         if self.use_bias:
             x = x + self.bias.view(1, -1, 1)
         x = x.transpose(1, 2)  # B x T x C

--
Gitblit v1.9.1