zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/rwkv_bat/rwkv_subsampling.py
@@ -62,17 +62,49 @@
            conv_size1, conv_size2, conv_size3 = conv_size
            self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                torch.nn.Conv2d(
                    1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2
                ),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=[1, 2], padding=(conv_kernel_size-1)//2),
                torch.nn.Conv2d(
                    conv_size1,
                    conv_size1,
                    conv_kernel_size,
                    stride=[1, 2],
                    padding=(conv_kernel_size - 1) // 2,
                ),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                torch.nn.Conv2d(
                    conv_size1,
                    conv_size2,
                    conv_kernel_size,
                    stride=1,
                    padding=(conv_kernel_size - 1) // 2,
                ),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=[1, 2], padding=(conv_kernel_size-1)//2),
                torch.nn.Conv2d(
                    conv_size2,
                    conv_size2,
                    conv_kernel_size,
                    stride=[1, 2],
                    padding=(conv_kernel_size - 1) // 2,
                ),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size3, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                torch.nn.Conv2d(
                    conv_size2,
                    conv_size3,
                    conv_kernel_size,
                    stride=1,
                    padding=(conv_kernel_size - 1) // 2,
                ),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size3, conv_size3, conv_kernel_size, stride=[1, 2], padding=(conv_kernel_size-1)//2),
                torch.nn.Conv2d(
                    conv_size3,
                    conv_size3,
                    conv_kernel_size,
                    stride=[1, 2],
                    padding=(conv_kernel_size - 1) // 2,
                ),
                    torch.nn.ReLU(),
            )
@@ -90,17 +122,49 @@
            kernel_1 = int(subsampling_factor / 2)
            self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                torch.nn.Conv2d(
                    1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size - 1) // 2
                ),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=[kernel_1, 2], padding=(conv_kernel_size-1)//2),
                torch.nn.Conv2d(
                    conv_size1,
                    conv_size1,
                    conv_kernel_size,
                    stride=[kernel_1, 2],
                    padding=(conv_kernel_size - 1) // 2,
                ),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                torch.nn.Conv2d(
                    conv_size1,
                    conv_size2,
                    conv_kernel_size,
                    stride=1,
                    padding=(conv_kernel_size - 1) // 2,
                ),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=[2, 2], padding=(conv_kernel_size-1)//2),
                torch.nn.Conv2d(
                    conv_size2,
                    conv_size2,
                    conv_kernel_size,
                    stride=[2, 2],
                    padding=(conv_kernel_size - 1) // 2,
                ),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size3, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                torch.nn.Conv2d(
                    conv_size2,
                    conv_size3,
                    conv_kernel_size,
                    stride=1,
                    padding=(conv_kernel_size - 1) // 2,
                ),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size3, conv_size3, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                torch.nn.Conv2d(
                    conv_size3,
                    conv_size3,
                    conv_kernel_size,
                    stride=1,
                    padding=(conv_kernel_size - 1) // 2,
                ),
                    torch.nn.ReLU(),
            )
@@ -141,7 +205,9 @@
        if chunk_size is not None:
            max_input_length = int(
                chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) ))
                chunk_size
                * self.subsampling_factor
                * (math.ceil(float(t) / (chunk_size * self.subsampling_factor)))
            )
            x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
            x = list(x)