#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import nn

class BasicResBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicResBlock, self).__init__()
        # Stride applies to the frequency axis only, so time resolution is kept.
        self.conv1 = nn.Conv2d(in_planes,
                               planes,
                               kernel_size=3,
                               stride=(stride, 1),
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes,
                               planes,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            # Project the identity branch when the output shape changes.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes,
                          self.expansion * planes,
                          kernel_size=1,
                          stride=(stride, 1),
                          bias=False),
                nn.BatchNorm2d(self.expansion * planes))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
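
# Illustrative shape check (a sketch, not part of the original module): with
# stride=2 the block halves the frequency axis only and keeps the time axis.
#
#   block = BasicResBlock(32, 32, stride=2)
#   x = torch.randn(4, 32, 80, 200)   # (batch, channels, freq, time)
#   print(block(x).shape)             # torch.Size([4, 32, 40, 200])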


class FCM(nn.Module):
    def __init__(self,
                 block=BasicResBlock,
                 num_blocks=[2, 2],
                 m_channels=32,
                 feat_dim=80):
        super(FCM, self).__init__()
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)

        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)
        # Three 2x downsamplings along frequency: feat_dim -> feat_dim // 8.
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.unsqueeze(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.relu(self.bn2(self.conv2(out)))

        # Flatten channel and frequency axes into a single channel axis.
        shape = out.shape
        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
        return out
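
# Illustrative usage (a sketch, not part of the original module): with the
# defaults, 80-dim features are compressed to 80 // 8 = 10 frequency bins,
# giving 32 * 10 = 320 output channels while the frame count is unchanged.
#
#   fcm = FCM()
#   feats = torch.randn(4, 80, 200)   # (batch, feat_dim, frames)
#   print(fcm(feats).shape)           # torch.Size([4, 320, 200])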


def get_nonlinear(config_str, channels):
    nonlinear = nn.Sequential()
    for name in config_str.split('-'):
        if name == 'relu':
            nonlinear.add_module('relu', nn.ReLU(inplace=True))
        elif name == 'prelu':
            nonlinear.add_module('prelu', nn.PReLU(channels))
        elif name == 'batchnorm':
            nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels))
        elif name == 'batchnorm_':
            # Variant without learnable affine parameters.
            nonlinear.add_module('batchnorm',
                                 nn.BatchNorm1d(channels, affine=False))
        else:
            raise ValueError('Unexpected module ({}).'.format(name))
    return nonlinear
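
# Example (a sketch, not part of the original module): the hyphen-separated
# names are added in order, so 'batchnorm-relu' builds BatchNorm1d then ReLU.
#
#   print(get_nonlinear('batchnorm-relu', 64))
#   # Sequential((batchnorm): BatchNorm1d(64, ...), (relu): ReLU(inplace=True))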


def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True):
    # Per-channel mean and standard deviation over the time axis, concatenated.
    mean = x.mean(dim=dim)
    std = x.std(dim=dim, unbiased=unbiased)
    stats = torch.cat([mean, std], dim=-1)
    if keepdim:
        stats = stats.unsqueeze(dim=dim)
    return stats


class StatsPool(nn.Module):
    def forward(self, x):
        return statistics_pooling(x)
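
# Illustrative shape check (a sketch, not part of the original module):
# pooling concatenates mean and std, doubling the channel dimension.
#
#   pool = StatsPool()
#   x = torch.randn(4, 512, 200)   # (batch, channels, frames)
#   print(pool(x).shape)           # torch.Size([4, 1024])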


class TDNNLayer(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu'):
        super(TDNNLayer, self).__init__()
        if padding < 0:
            # Negative padding requests 'same' padding, which needs an odd kernel.
            assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
                kernel_size)
            padding = (kernel_size - 1) // 2 * dilation
        self.linear = nn.Conv1d(in_channels,
                                out_channels,
                                kernel_size,
                                stride=stride,
                                padding=padding,
                                dilation=dilation,
                                bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        x = self.linear(x)
        x = self.nonlinear(x)
        return x
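
# Illustrative usage (a sketch, not part of the original module): passing
# padding=-1 triggers the 'same' padding branch above, so the output keeps
# the input frame count.
#
#   tdnn = TDNNLayer(80, 512, kernel_size=5, padding=-1)
#   x = torch.randn(4, 80, 200)
#   print(tdnn(x).shape)           # torch.Size([4, 512, 200])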


class CAMLayer(nn.Module):
    def __init__(self,
                 bn_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation,
                 bias,
                 reduction=2):
        super(CAMLayer, self).__init__()
        self.linear_local = nn.Conv1d(bn_channels,
                                      out_channels,
                                      kernel_size,
                                      stride=stride,
                                      padding=padding,
                                      dilation=dilation,
                                      bias=bias)
        self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.linear_local(x)
        # Context combines a global average with segment-level statistics.
        context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
        context = self.relu(self.linear1(context))
        m = self.sigmoid(self.linear2(context))
        return y * m

    def seg_pooling(self, x, seg_len=100, stype='avg'):
        if stype == 'avg':
            seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
        elif stype == 'max':
            seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
        else:
            raise ValueError('Wrong segment pooling type.')
        shape = seg.shape
        # Repeat each segment value seg_len times, then trim to the input length.
        seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
        seg = seg[..., :x.shape[-1]]
        return seg
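
# Shape note (a sketch, not part of the original module): the context-aware
# mask m is multiplied onto the local convolution output, so with stride=1
# and 'same' padding the frame count is preserved.
#
#   cam = CAMLayer(128, 64, kernel_size=3, stride=1, padding=1,
#                  dilation=1, bias=False)
#   x = torch.randn(4, 128, 200)
#   print(cam(x).shape)            # torch.Size([4, 64, 200])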


class CAMDenseTDNNLayer(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(CAMDenseTDNNLayer, self).__init__()
        assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
            kernel_size)
        padding = (kernel_size - 1) // 2 * dilation
        self.memory_efficient = memory_efficient
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        self.cam_layer = CAMLayer(bn_channels, out_channels, kernel_size,
                                  stride=stride, padding=padding,
                                  dilation=dilation, bias=bias)

    def bn_function(self, x):
        return self.linear1(self.nonlinear1(x))

    def forward(self, x):
        if self.training and self.memory_efficient:
            # Recompute the bottleneck in backward to save activation memory.
            x = cp.checkpoint(self.bn_function, x)
        else:
            x = self.bn_function(x)
        x = self.cam_layer(self.nonlinear2(x))
        return x
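
# Note on memory_efficient (a sketch, not part of the original module):
# during training the bottleneck is recomputed in the backward pass via
# torch.utils.checkpoint, trading compute for activation memory.
#
#   layer = CAMDenseTDNNLayer(128, 32, bn_channels=64, kernel_size=3,
#                             memory_efficient=True)
#   x = torch.randn(4, 128, 200, requires_grad=True)
#   print(layer(x).shape)          # torch.Size([4, 32, 200])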


class CAMDenseTDNNBlock(nn.ModuleList):
    def __init__(self,
                 num_layers,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(CAMDenseTDNNBlock, self).__init__()
        for i in range(num_layers):
            layer = CAMDenseTDNNLayer(in_channels=in_channels + i * out_channels,
                                      out_channels=out_channels,
                                      bn_channels=bn_channels,
                                      kernel_size=kernel_size,
                                      stride=stride,
                                      dilation=dilation,
                                      bias=bias,
                                      config_str=config_str,
                                      memory_efficient=memory_efficient)
            self.add_module('tdnnd%d' % (i + 1), layer)

    def forward(self, x):
        # Dense connectivity: concatenate each layer's output to its input.
        for layer in self:
            x = torch.cat([x, layer(x)], dim=1)
        return x
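
# Channel-growth check (a sketch, not part of the original module): each
# layer appends out_channels channels, so the block emits
# in_channels + num_layers * out_channels channels.
#
#   block = CAMDenseTDNNBlock(num_layers=4, in_channels=128, out_channels=32,
#                             bn_channels=64, kernel_size=3)
#   x = torch.randn(4, 128, 200)
#   print(block(x).shape)          # torch.Size([4, 256, 200]) = 128 + 4 * 32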


class TransitLayer(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=True,
                 config_str='batchnorm-relu'):
        super(TransitLayer, self).__init__()
        self.nonlinear = get_nonlinear(config_str, in_channels)
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        x = self.nonlinear(x)
        x = self.linear(x)
        return x


class DenseLayer(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=False,
                 config_str='batchnorm-relu'):
        super(DenseLayer, self).__init__()
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):