游雁
2024-04-29 2779602177ae5374547c7a7e17de0b11a166326d
funasr/models/transformer/utils/dynamic_conv2d.py
@@ -95,9 +95,7 @@
        # convolution of frequency axis
        weight_f = self.linear_weight_f(x).view(B * T, 1, k)  # B x T x k
        self.attn_f = weight_f.view(B, T, k).unsqueeze(1)
        xf = F.conv1d(
            x.view(1, B * T, C), weight_f, padding=self.padding_size, groups=B * T
        )
        xf = F.conv1d(x.view(1, B * T, C), weight_f, padding=self.padding_size, groups=B * T)
        xf = xf.view(B, T, C)
        # get kernel of convolution
@@ -107,9 +105,9 @@
        weight_new = torch.zeros(B * H * T * (T + k - 1), dtype=weight.dtype)
        weight_new = weight_new.view(B, H, T, T + k - 1).fill_(float("-inf"))
        weight_new = weight_new.to(x.device)  # B x H x T x T+k-1
        weight_new.as_strided(
            (B, H, T, k), ((T + k - 1) * T * H, (T + k - 1) * T, T + k, 1)
        ).copy_(weight)
        weight_new.as_strided((B, H, T, k), ((T + k - 1) * T * H, (T + k - 1) * T, T + k, 1)).copy_(
            weight
        )
        weight_new = weight_new.narrow(-1, int((k - 1) / 2), T)  # B x H x T x T(k)
        if self.use_kernel_mask:
            kernel_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0)