import torch
import torch.nn as nn

class ASTP(nn.Module):
    """Attentive statistics pooling: channel- and context-dependent
    statistics pooling, first used in ECAPA_TDNN.
    """

    def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with kernel_size=1 rather than Linear, so we don't
        # need to transpose inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(
                in_dim * 3, bottleneck_dim, kernel_size=1
            )  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(
                in_dim, bottleneck_dim, kernel_size=1
            )  # equals W and b in the paper
        self.linear2 = nn.Conv1d(
            bottleneck_dim, in_dim, kernel_size=1
        )  # equals V and k in the paper

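    # A reading of the "W and b" / "V and k" comments above (notation from the
    # attentive statistics pooling formulation, not spelled out in this file):
    # per channel c and frame t, the score is
    #     e_{t,c} = v_c^T tanh(W h_t + b) + k_c,
    # and the attention weights are alpha_{t,c} = softmax over t of e_{t,c}.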
    def forward(self, x):
        """
        x: a 3-dimensional tensor of shape (B, F, T), where the last
           dimension is time (frames).
        """
        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(
                torch.var(x, dim=-1, keepdim=True) + 1e-10
            ).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # DON'T use ReLU here! With ReLU, training may be hard to converge.
        alpha = torch.tanh(self.linear1(x_in))  # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        var = torch.sum(alpha * (x**2), dim=2) - mean**2
        std = torch.sqrt(var.clamp(min=1e-10))
        return torch.cat([mean, std], dim=1)
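
# Usage sketch (added, not part of the original module; the dimensions are
# illustrative assumptions): pool a (batch, channels, time) feature map into
# a fixed-size utterance embedding of length 2 * in_dim, i.e. the
# concatenated attentive mean and std.
if __name__ == "__main__":
    pool = ASTP(in_dim=256, global_context_att=True)
    feats = torch.randn(4, 256, 200)  # 4 utterances, 256 channels, 200 frames
    embedding = pool(feats)
    print(embedding.shape)  # torch.Size([4, 512])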