From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/sond/encoder/ecapa_tdnn_encoder.py |   64 ++++++++++++-------------------
 1 files changed, 25 insertions(+), 39 deletions(-)

diff --git a/funasr/models/sond/encoder/ecapa_tdnn_encoder.py b/funasr/models/sond/encoder/ecapa_tdnn_encoder.py
index 878a3c0..1af8b70 100644
--- a/funasr/models/sond/encoder/ecapa_tdnn_encoder.py
+++ b/funasr/models/sond/encoder/ecapa_tdnn_encoder.py
@@ -39,9 +39,7 @@
             if x.ndim == 3:
                 x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
             else:
-                x = x.reshape(
-                    shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
-                )
+                x = x.reshape(shape_or[0] * shape_or[1], shape_or[3], shape_or[2])
 
         elif not self.skip_transpose:
             x = x.transpose(-1, 1)
@@ -105,9 +103,7 @@
             x = x.unsqueeze(1)
 
         if self.padding == "same":
-            x = self._manage_padding(
-                x, self.kernel_size, self.dilation, self.stride
-            )
+            x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
 
         elif self.padding == "causal":
             num_pad = (self.kernel_size - 1) * self.dilation
@@ -117,10 +113,7 @@
             pass
 
         else:
-            raise ValueError(
-                "Padding must be 'same', 'valid' or 'causal'. Got "
-                + self.padding
-            )
+            raise ValueError("Padding must be 'same', 'valid' or 'causal'. Got " + self.padding)
 
         wx = self.conv(x)
 
@@ -133,7 +126,11 @@
         return wx
 
     def _manage_padding(
-        self, x, kernel_size: int, dilation: int, stride: int,
+        self,
+        x,
+        kernel_size: int,
+        dilation: int,
+        stride: int,
     ):
         # Detecting input shape
         L_in = x.shape[-1]
@@ -147,8 +144,7 @@
         return x
 
     def _check_input_shape(self, shape):
-        """Checks the input shape and returns the number of input channels.
-        """
+        """Checks the input shape and returns the number of input channels."""
 
         if len(shape) == 2:
             self.unsqueeze = True
@@ -158,15 +154,12 @@
         elif len(shape) == 3:
             in_channels = shape[2]
         else:
-            raise ValueError(
-                "conv1d expects 2d, 3d inputs. Got " + str(len(shape))
-            )
+            raise ValueError("conv1d expects 2d, 3d inputs. Got " + str(len(shape)))
 
         # Kernel size must be odd
         if self.kernel_size % 2 == 0:
             raise ValueError(
-                "The field kernel size must be an odd number. Got %s."
-                % (self.kernel_size)
+                "The field kernel size must be an odd number. Got %s." % (self.kernel_size)
             )
         return in_channels
 
@@ -200,9 +193,9 @@
 
     if max_len is None:
         max_len = length.max().long().item()  # using arange to generate mask
-    mask = torch.arange(
-        max_len, device=length.device, dtype=length.dtype
-    ).expand(len(length), max_len) < length.unsqueeze(1)
+    mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
+        len(length), max_len
+    ) < length.unsqueeze(1)
 
     if dtype is None:
         dtype = length.dtype
@@ -264,9 +257,7 @@
     torch.Size([8, 120, 64])
     """
 
-    def __init__(
-        self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
-    ):
+    def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
         super(Res2NetBlock, self).__init__()
         assert in_channels % scale == 0
         assert out_channels % scale == 0
@@ -326,13 +317,9 @@
     def __init__(self, in_channels, se_channels, out_channels):
         super(SEBlock, self).__init__()
 
-        self.conv1 = Conv1d(
-            in_channels=in_channels, out_channels=se_channels, kernel_size=1
-        )
+        self.conv1 = Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1)
         self.relu = torch.nn.ReLU(inplace=True)
-        self.conv2 = Conv1d(
-            in_channels=se_channels, out_channels=out_channels, kernel_size=1
-        )
+        self.conv2 = Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1)
         self.sigmoid = torch.nn.Sigmoid()
 
     def forward(self, x, lengths=None):
@@ -382,9 +369,7 @@
         else:
             self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
         self.tanh = nn.Tanh()
-        self.conv = Conv1d(
-            in_channels=attention_channels, out_channels=channels, kernel_size=1
-        )
+        self.conv = Conv1d(in_channels=attention_channels, out_channels=channels, kernel_size=1)
 
     def forward(self, x, lengths=None):
         """Calculates mean and std for a batch (input tensor).
@@ -398,9 +383,7 @@
 
         def _compute_statistics(x, m, dim=2, eps=self.eps):
             mean = (m * x).sum(dim)
-            std = torch.sqrt(
-                (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
-            )
+            std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
             return mean, std
 
         if lengths is None:
@@ -638,9 +621,12 @@
         for i in range(num_chunk):
             # B x C
             st, ed = i * self.window_shift, i * self.window_shift + self.window_size
-            x = self.asp(x[:, :, st: ed],
-                         lengths=torch.clamp(lengths - i, 0, self.window_size)
-                         if lengths is not None else None)
+            x = self.asp(
+                x[:, :, st:ed],
+                lengths=(
+                    torch.clamp(lengths - i, 0, self.window_size) if lengths is not None else None
+                ),
+            )
             x = self.asp_bn(x)
             x = self.fc(x)
             stat_list.append(x)

--
Gitblit v1.9.1