From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/models/sond/encoder/ecapa_tdnn_encoder.py | 64 ++++++++++++-------------------
1 files changed, 25 insertions(+), 39 deletions(-)
diff --git a/funasr/models/sond/encoder/ecapa_tdnn_encoder.py b/funasr/models/sond/encoder/ecapa_tdnn_encoder.py
index 878a3c0..1af8b70 100644
--- a/funasr/models/sond/encoder/ecapa_tdnn_encoder.py
+++ b/funasr/models/sond/encoder/ecapa_tdnn_encoder.py
@@ -39,9 +39,7 @@
if x.ndim == 3:
x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
else:
- x = x.reshape(
- shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
- )
+ x = x.reshape(shape_or[0] * shape_or[1], shape_or[3], shape_or[2])
elif not self.skip_transpose:
x = x.transpose(-1, 1)
@@ -105,9 +103,7 @@
x = x.unsqueeze(1)
if self.padding == "same":
- x = self._manage_padding(
- x, self.kernel_size, self.dilation, self.stride
- )
+ x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
elif self.padding == "causal":
num_pad = (self.kernel_size - 1) * self.dilation
@@ -117,10 +113,7 @@
pass
else:
- raise ValueError(
- "Padding must be 'same', 'valid' or 'causal'. Got "
- + self.padding
- )
+ raise ValueError("Padding must be 'same', 'valid' or 'causal'. Got " + self.padding)
wx = self.conv(x)
@@ -133,7 +126,11 @@
return wx
def _manage_padding(
- self, x, kernel_size: int, dilation: int, stride: int,
+ self,
+ x,
+ kernel_size: int,
+ dilation: int,
+ stride: int,
):
# Detecting input shape
L_in = x.shape[-1]
@@ -147,8 +144,7 @@
return x
def _check_input_shape(self, shape):
- """Checks the input shape and returns the number of input channels.
- """
+ """Checks the input shape and returns the number of input channels."""
if len(shape) == 2:
self.unsqueeze = True
@@ -158,15 +154,12 @@
elif len(shape) == 3:
in_channels = shape[2]
else:
- raise ValueError(
- "conv1d expects 2d, 3d inputs. Got " + str(len(shape))
- )
+ raise ValueError("conv1d expects 2d, 3d inputs. Got " + str(len(shape)))
# Kernel size must be odd
if self.kernel_size % 2 == 0:
raise ValueError(
- "The field kernel size must be an odd number. Got %s."
- % (self.kernel_size)
+ "The field kernel size must be an odd number. Got %s." % (self.kernel_size)
)
return in_channels
@@ -200,9 +193,9 @@
if max_len is None:
max_len = length.max().long().item() # using arange to generate mask
- mask = torch.arange(
- max_len, device=length.device, dtype=length.dtype
- ).expand(len(length), max_len) < length.unsqueeze(1)
+ mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
+ len(length), max_len
+ ) < length.unsqueeze(1)
if dtype is None:
dtype = length.dtype
@@ -264,9 +257,7 @@
torch.Size([8, 120, 64])
"""
- def __init__(
- self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
- ):
+ def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
super(Res2NetBlock, self).__init__()
assert in_channels % scale == 0
assert out_channels % scale == 0
@@ -326,13 +317,9 @@
def __init__(self, in_channels, se_channels, out_channels):
super(SEBlock, self).__init__()
- self.conv1 = Conv1d(
- in_channels=in_channels, out_channels=se_channels, kernel_size=1
- )
+ self.conv1 = Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1)
self.relu = torch.nn.ReLU(inplace=True)
- self.conv2 = Conv1d(
- in_channels=se_channels, out_channels=out_channels, kernel_size=1
- )
+ self.conv2 = Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x, lengths=None):
@@ -382,9 +369,7 @@
else:
self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
self.tanh = nn.Tanh()
- self.conv = Conv1d(
- in_channels=attention_channels, out_channels=channels, kernel_size=1
- )
+ self.conv = Conv1d(in_channels=attention_channels, out_channels=channels, kernel_size=1)
def forward(self, x, lengths=None):
"""Calculates mean and std for a batch (input tensor).
@@ -398,9 +383,7 @@
def _compute_statistics(x, m, dim=2, eps=self.eps):
mean = (m * x).sum(dim)
- std = torch.sqrt(
- (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
- )
+ std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
return mean, std
if lengths is None:
@@ -638,9 +621,12 @@
for i in range(num_chunk):
# B x C
st, ed = i * self.window_shift, i * self.window_shift + self.window_size
- x = self.asp(x[:, :, st: ed],
- lengths=torch.clamp(lengths - i, 0, self.window_size)
- if lengths is not None else None)
+ x = self.asp(
+ x[:, :, st:ed],
+ lengths=(
+ torch.clamp(lengths - i, 0, self.window_size) if lengths is not None else None
+ ),
+ )
x = self.asp_bn(x)
x = self.fc(x)
stat_list.append(x)
--
Gitblit v1.9.1