From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/sond/pooling/pooling_layers.py |   25 ++++++++++++-------------
 1 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/funasr/models/sond/pooling/pooling_layers.py b/funasr/models/sond/pooling/pooling_layers.py
index 0aa10fe..0a426a9 100644
--- a/funasr/models/sond/pooling/pooling_layers.py
+++ b/funasr/models/sond/pooling/pooling_layers.py
@@ -59,8 +59,8 @@
 
 
 class ASTP(nn.Module):
-    """ Attentive statistics pooling: Channel- and context-dependent
-        statistics pooling, first used in ECAPA_TDNN.
+    """Attentive statistics pooling: Channel- and context-dependent
+    statistics pooling, first used in ECAPA_TDNN.
     """
 
     def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
@@ -71,14 +71,15 @@
         # need to transpose inputs.
         if global_context_att:
             self.linear1 = nn.Conv1d(
-                in_dim * 3, bottleneck_dim,
-                kernel_size=1)  # equals W and b in the paper
+                in_dim * 3, bottleneck_dim, kernel_size=1
+            )  # equals W and b in the paper
         else:
             self.linear1 = nn.Conv1d(
-                in_dim, bottleneck_dim,
-                kernel_size=1)  # equals W and b in the paper
-        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
-                                 kernel_size=1)  # equals V and k in the paper
+                in_dim, bottleneck_dim, kernel_size=1
+            )  # equals W and b in the paper
+        self.linear2 = nn.Conv1d(
+            bottleneck_dim, in_dim, kernel_size=1
+        )  # equals V and k in the paper
 
     def forward(self, x):
         """
@@ -92,17 +93,15 @@
 
         if self.global_context_att:
             context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
-            context_std = torch.sqrt(
-                torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
+            context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
             x_in = torch.cat((x, context_mean, context_std), dim=1)
         else:
             x_in = x
 
         # DON'T use ReLU here! ReLU may be hard to converge.
-        alpha = torch.tanh(
-            self.linear1(x_in))  # alpha = F.relu(self.linear1(x_in))
+        alpha = torch.tanh(self.linear1(x_in))  # alpha = F.relu(self.linear1(x_in))
         alpha = torch.softmax(self.linear2(alpha), dim=2)
         mean = torch.sum(alpha * x, dim=2)
-        var = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
+        var = torch.sum(alpha * (x**2), dim=2) - mean**2
         std = torch.sqrt(var.clamp(min=1e-10))
         return torch.cat([mean, std], dim=1)

--
Gitblit v1.9.1