majic31
2024-12-24 23e7ddebccd3b05cf7ef89809bcfe565ad6dfa1f
funasr/models/transformer/utils/lightconv2d.py
@@ -50,9 +50,7 @@
        self.act = nn.GLU()
        # lightconv related
        self.weight = nn.Parameter(
            torch.Tensor(self.wshare, 1, kernel_size).uniform_(0, 1)
        )
        self.weight = nn.Parameter(torch.Tensor(self.wshare, 1, kernel_size).uniform_(0, 1))
        self.weight_f = nn.Parameter(torch.Tensor(1, 1, kernel_size).uniform_(0, 1))
        self.use_bias = use_bias
        if self.use_bias:
@@ -93,9 +91,9 @@
        # convolution along frequency axis
        weight_f = F.softmax(self.weight_f, dim=-1)
        weight_f = F.dropout(weight_f, self.dropout_rate, training=self.training)
        weight_new = torch.zeros(
            B * T, 1, self.kernel_size, device=x.device, dtype=x.dtype
        ).copy_(weight_f)
        weight_new = torch.zeros(B * T, 1, self.kernel_size, device=x.device, dtype=x.dtype).copy_(
            weight_f
        )
        xf = F.conv1d(
            x.view(1, B * T, C), weight_new, padding=self.padding_size, groups=B * T
        ).view(B, T, C)
@@ -107,9 +105,7 @@
            self.kernel_mask = self.kernel_mask.to(x.device)
            weight = weight.masked_fill(self.kernel_mask == 0.0, float("-inf"))
        weight = F.softmax(weight, dim=-1)
        x = F.conv1d(x, weight, padding=self.padding_size, groups=self.wshare).view(
            B, C, T
        )
        x = F.conv1d(x, weight, padding=self.padding_size, groups=self.wshare).view(B, C, T)
        if self.use_bias:
            x = x + self.bias.view(1, -1, 1)
        x = x.transpose(1, 2)  # B x T x C