From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/models/whisper_lid/eres2net/pooling_layers.py | 29 +++++++++++++++--------------
1 files changed, 15 insertions(+), 14 deletions(-)
diff --git a/funasr/models/whisper_lid/eres2net/pooling_layers.py b/funasr/models/whisper_lid/eres2net/pooling_layers.py
index f756ac8..da23d7b 100644
--- a/funasr/models/whisper_lid/eres2net/pooling_layers.py
+++ b/funasr/models/whisper_lid/eres2net/pooling_layers.py
@@ -57,7 +57,9 @@
count_without_padding = torch.sum(masks, axis=-1)
mean_without_padding = sum_without_padding / count_without_padding
- var_without_padding = ((x_masked - mean_without_padding.unsqueeze(-1)) ** 2 * masks).sum(-1) / count_without_padding
+ var_without_padding = ((x_masked - mean_without_padding.unsqueeze(-1)) ** 2 * masks).sum(
+ -1
+ ) / count_without_padding
pooling_mean = mean_without_padding
pooling_std = torch.sqrt(var_without_padding + 1e-8)
@@ -69,8 +71,8 @@
class ASTP(nn.Module):
- """ Attentive statistics pooling: Channel- and context-dependent
- statistics pooling, first used in ECAPA_TDNN.
+ """Attentive statistics pooling: Channel- and context-dependent
+ statistics pooling, first used in ECAPA_TDNN.
"""
def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
@@ -81,14 +83,15 @@
# need to transpose inputs.
if global_context_att:
self.linear1 = nn.Conv1d(
- in_dim * 3, bottleneck_dim,
- kernel_size=1) # equals W and b in the paper
+ in_dim * 3, bottleneck_dim, kernel_size=1
+ ) # equals W and b in the paper
else:
self.linear1 = nn.Conv1d(
- in_dim, bottleneck_dim,
- kernel_size=1) # equals W and b in the paper
- self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
- kernel_size=1) # equals V and k in the paper
+ in_dim, bottleneck_dim, kernel_size=1
+ ) # equals W and b in the paper
+ self.linear2 = nn.Conv1d(
+ bottleneck_dim, in_dim, kernel_size=1
+ ) # equals V and k in the paper
def forward(self, x):
"""
@@ -102,17 +105,15 @@
if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
- context_std = torch.sqrt(
- torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
+ context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
x_in = torch.cat((x, context_mean, context_std), dim=1)
else:
x_in = x
# DON'T use ReLU here! ReLU may be hard to converge.
- alpha = torch.tanh(
- self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
+ alpha = torch.tanh(self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2)
- var = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
+ var = torch.sum(alpha * (x**2), dim=2) - mean**2
std = torch.sqrt(var.clamp(min=1e-10))
return torch.cat([mean, std], dim=1)
--
Gitblit v1.9.1