liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/sond/encoder/ecapa_tdnn_encoder.py
@@ -39,9 +39,7 @@
             if x.ndim == 3:
                 x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
             else:
-                x = x.reshape(
-                    shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
-                )
+                x = x.reshape(shape_or[0] * shape_or[1], shape_or[3], shape_or[2])
         elif not self.skip_transpose:
             x = x.transpose(-1, 1)
@@ -105,9 +103,7 @@
             x = x.unsqueeze(1)
         if self.padding == "same":
-            x = self._manage_padding(
-                x, self.kernel_size, self.dilation, self.stride
-            )
+            x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
         elif self.padding == "causal":
             num_pad = (self.kernel_size - 1) * self.dilation
@@ -117,10 +113,7 @@
             pass
         else:
-            raise ValueError(
-                "Padding must be 'same', 'valid' or 'causal'. Got "
-                + self.padding
-            )
+            raise ValueError("Padding must be 'same', 'valid' or 'causal'. Got " + self.padding)
         wx = self.conv(x)
@@ -133,7 +126,11 @@
         return wx

     def _manage_padding(
-        self, x, kernel_size: int, dilation: int, stride: int,
+        self,
+        x,
+        kernel_size: int,
+        dilation: int,
+        stride: int,
     ):
         # Detecting input shape
         L_in = x.shape[-1]
@@ -147,8 +144,7 @@
         return x

     def _check_input_shape(self, shape):
-        """Checks the input shape and returns the number of input channels.
-        """
+        """Checks the input shape and returns the number of input channels."""
         if len(shape) == 2:
             self.unsqueeze = True
@@ -158,15 +154,12 @@
         elif len(shape) == 3:
             in_channels = shape[2]
         else:
-            raise ValueError(
-                "conv1d expects 2d, 3d inputs. Got " + str(len(shape))
-            )
+            raise ValueError("conv1d expects 2d, 3d inputs. Got " + str(len(shape)))
         # Kernel size must be odd
         if self.kernel_size % 2 == 0:
             raise ValueError(
-                "The field kernel size must be an odd number. Got %s."
-                % (self.kernel_size)
+                "The field kernel size must be an odd number. Got %s." % (self.kernel_size)
             )
         return in_channels
@@ -200,9 +193,9 @@
     if max_len is None:
         max_len = length.max().long().item()  # using arange to generate mask
-    mask = torch.arange(
-        max_len, device=length.device, dtype=length.dtype
-    ).expand(len(length), max_len) < length.unsqueeze(1)
+    mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
+        len(length), max_len
+    ) < length.unsqueeze(1)
     if dtype is None:
         dtype = length.dtype
@@ -264,9 +257,7 @@
     torch.Size([8, 120, 64])
     """

-    def __init__(
-        self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
-    ):
+    def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
         super(Res2NetBlock, self).__init__()
         assert in_channels % scale == 0
         assert out_channels % scale == 0
@@ -326,13 +317,9 @@
     def __init__(self, in_channels, se_channels, out_channels):
         super(SEBlock, self).__init__()
-        self.conv1 = Conv1d(
-            in_channels=in_channels, out_channels=se_channels, kernel_size=1
-        )
+        self.conv1 = Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1)
         self.relu = torch.nn.ReLU(inplace=True)
-        self.conv2 = Conv1d(
-            in_channels=se_channels, out_channels=out_channels, kernel_size=1
-        )
+        self.conv2 = Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1)
         self.sigmoid = torch.nn.Sigmoid()

     def forward(self, x, lengths=None):
@@ -382,9 +369,7 @@
         else:
             self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
         self.tanh = nn.Tanh()
-        self.conv = Conv1d(
-            in_channels=attention_channels, out_channels=channels, kernel_size=1
-        )
+        self.conv = Conv1d(in_channels=attention_channels, out_channels=channels, kernel_size=1)

     def forward(self, x, lengths=None):
         """Calculates mean and std for a batch (input tensor).
@@ -398,9 +383,7 @@
         def _compute_statistics(x, m, dim=2, eps=self.eps):
             mean = (m * x).sum(dim)
-            std = torch.sqrt(
-                (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
-            )
+            std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
             return mean, std
         if lengths is None:
@@ -638,9 +621,12 @@
         for i in range(num_chunk):
             # B x C
             st, ed = i * self.window_shift, i * self.window_shift + self.window_size
-            x = self.asp(x[:, :, st: ed],
-                         lengths=torch.clamp(lengths - i, 0, self.window_size)
-                         if lengths is not None else None)
+            x = self.asp(
+                x[:, :, st:ed],
+                lengths=(
+                    torch.clamp(lengths - i, 0, self.window_size) if lengths is not None else None
+                ),
+            )
             x = self.asp_bn(x)
             x = self.fc(x)
             stat_list.append(x)