From 4ace5a95b052d338947fc88809a440ccd55cf6b4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 十一月 2023 16:39:52 +0800
Subject: [PATCH] funasr pages

---
 funasr/modules/subsampling.py |   24 +++++++++++++-----------
 1 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/funasr/modules/subsampling.py b/funasr/modules/subsampling.py
index a2b91a7..af33aef 100644
--- a/funasr/modules/subsampling.py
+++ b/funasr/modules/subsampling.py
@@ -358,7 +358,8 @@
         """
         x = x.transpose(1, 2)  # (b, d ,t)
         x = self.pad_fn(x)
-        x = F.relu(self.conv(x))
+        #x = F.relu(self.conv(x))
+        x = F.leaky_relu(self.conv(x), negative_slope=0.)
         x = x.transpose(1, 2)  # (b, t ,d)
 
         if x_len is None:
@@ -427,6 +428,7 @@
         conv_size: Union[int, Tuple],
         subsampling_factor: int = 4,
         vgg_like: bool = True,
+        conv_kernel_size: int = 3,
         output_size: Optional[int] = None,
     ) -> None:
         """Construct a ConvInput object."""
@@ -436,14 +438,14 @@
                 conv_size1, conv_size2 = conv_size
 
                 self.conv = torch.nn.Sequential(
-                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
                     torch.nn.MaxPool2d((1, 2)),
-                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
                     torch.nn.MaxPool2d((1, 2)),
                 )
@@ -462,14 +464,14 @@
                 kernel_1 = int(subsampling_factor / 2)
 
                 self.conv = torch.nn.Sequential(
-                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
                     torch.nn.MaxPool2d((kernel_1, 2)),
-                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
                     torch.nn.MaxPool2d((2, 2)),
                 )
@@ -487,14 +489,14 @@
                 self.conv = torch.nn.Sequential(
                     torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
+                    torch.nn.Conv2d(conv_size, conv_size, conv_kernel_size, [1,2], [1,0]),
                     torch.nn.ReLU(),
                 )
 
                 output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
 
                 self.subsampling_factor = subsampling_factor
-                self.kernel_2 = 3
+                self.kernel_2 = conv_kernel_size
                 self.stride_2 = 1
 
                 self.create_new_mask = self.create_new_conv2d_mask

--
Gitblit v1.9.1