
            if x.ndim == 3:
                x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
            else:
                x = x.reshape(
                    shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
                )

        elif not self.skip_transpose:
            x = x.transpose(-1, 1)
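
        # Added note: the reshape above folds the first two dims of the input
        # into a single batch axis so the underlying conv sees a 3-D (or 2-D)
        # tensor; otherwise the transpose swaps the channel and time axes.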
        if self.unsqueeze:
            x = x.unsqueeze(1)

        if self.padding == "same":
            x = self._manage_padding(
                x, self.kernel_size, self.dilation, self.stride
            )

        elif self.padding == "causal":
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))  # left-pad only, so the conv stays causal

        elif self.padding == "valid":
            pass

        else:
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got "
                + self.padding
            )

        wx = self.conv(x)

        if self.unsqueeze:
            wx = wx.squeeze(1)

        return wx
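
    # Illustrative usage of the padding modes above (added example; the shapes
    # are assumed, not taken from this file):
    #
    #     conv = Conv1d(in_channels=40, out_channels=64, kernel_size=3, padding="same")
    #     y = conv(x)  # "same": y keeps x's time length
    #
    # "causal" left-pads by (kernel_size - 1) * dilation so output frame t only
    # depends on input frames <= t, and "valid" applies no padding at all.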

    def _manage_padding(
        self, x, kernel_size: int, dilation: int, stride: int
    ):
        # Detecting input shape
        L_in = x.shape[-1]

        return x
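
    # A minimal sketch (added; not part of the original file) of the "same"
    # padding computation elided above: pick (left, right) pad sizes so that a
    # stride-1 dilated convolution preserves the input length. The helper name
    # is hypothetical.
    @staticmethod
    def _same_pad_amounts(L_in: int, kernel_size: int, dilation: int):
        span = dilation * (kernel_size - 1) + 1  # effective kernel footprint
        L_out = L_in - span + 1  # output length with no padding
        total = L_in - L_out  # frames to add back (= span - 1)
        return [total // 2, total - total // 2]  # (left, right) for F.pad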

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels."""

        if len(shape) == 2:
            self.unsqueeze = True
            in_channels = 1
        elif len(shape) == 3:
            in_channels = shape[2]
        else:
            raise ValueError(
                "conv1d expects 2d or 3d inputs. Got " + str(len(shape))
            )

        # Kernel size must be odd (an even kernel cannot be padded symmetrically)
        if self.kernel_size % 2 == 0:
            raise ValueError(
                "The field kernel size must be an odd number. Got %s."
                % (self.kernel_size)
            )
        return in_channels


def length_to_mask(length, max_len=None, dtype=None, device=None):
    """Creates a binary mask for each sequence, given the sequence lengths."""
    if max_len is None:
        max_len = length.max().long().item()  # using arange to generate mask
    mask = torch.arange(
        max_len, device=length.device, dtype=length.dtype
    ).expand(len(length), max_len) < length.unsqueeze(1)

    if dtype is None:
        dtype = length.dtype
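
# Example (added illustration): length_to_mask(torch.tensor([2, 3]), max_len=4)
# returns [[1, 1, 0, 0],
#          [1, 1, 1, 0]]
# i.e. one row per sequence, true up to that sequence's length.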

    torch.Size([8, 120, 64])
    """

    def __init__(
        self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
    ):
        super(Res2NetBlock, self).__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0
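
        # Added note: a Res2Net block splits its channels into `scale` groups;
        # each group is convolved and added into the next group's input, giving
        # multi-scale receptive fields for roughly the cost of one convolution.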

    def __init__(self, in_channels, se_channels, out_channels):
        super(SEBlock, self).__init__()

        self.conv1 = Conv1d(
            in_channels=in_channels, out_channels=se_channels, kernel_size=1
        )
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = Conv1d(
            in_channels=se_channels, out_channels=out_channels, kernel_size=1
        )
        self.sigmoid = torch.nn.Sigmoid()
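
        # Rough sketch (added; an assumption about the elided forward() body)
        # of the squeeze-and-excitation gating these layers build:
        #   s = x.mean(dim=-1, keepdim=True)   # squeeze: per-channel context
        #   s = self.relu(self.conv1(s))       # bottleneck down to se_channels
        #   s = self.sigmoid(self.conv2(s))    # gate values in (0, 1)
        #   out = s * x                        # excite: rescale each channel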

    def forward(self, x, lengths=None):

        else:
            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = Conv1d(
            in_channels=attention_channels, out_channels=channels, kernel_size=1
        )

    def forward(self, x, lengths=None):
        """Calculates mean and std for a batch (input tensor).

        def _compute_statistics(x, m, dim=2, eps=self.eps):
            mean = (m * x).sum(dim)
            std = torch.sqrt(
                (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
            )
            return mean, std
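
        # Added note: `m` is a weight mask over the time axis (attention
        # weights or a normalized length mask), so `mean` and `std` above are
        # weighted statistics; .clamp(eps) keeps the variance positive before
        # the square root.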

        if lengths is None:

        for i in range(num_chunk):
            # B x C
            st, ed = i * self.window_shift, i * self.window_shift + self.window_size
            x = self.asp(
                x[:, :, st:ed],
                lengths=(
                    torch.clamp(lengths - i, 0, self.window_size)
                    if lengths is not None
                    else None
                ),
            )
            x = self.asp_bn(x)
            x = self.fc(x)
            stat_list.append(x)
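
            # Added note: each pass pools one window of `window_size` frames,
            # advanced by `window_shift`, so `stat_list` gathers one pooled
            # statistics vector per chunk rather than one per utterance.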