From 6f7e27eb7c2d0a7649ec8f14d167c8da8e29f906 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 16 五月 2023 15:07:20 +0800
Subject: [PATCH] Merge pull request #518 from alibaba-damo-academy/dev_wjm2

---
 funasr/models/encoder/mfcca_encoder.py |  120 +++++++++++++++++++++++++++---------------------------------
 1 files changed, 54 insertions(+), 66 deletions(-)

diff --git a/funasr/models/encoder/mfcca_encoder.py b/funasr/models/encoder/mfcca_encoder.py
index 83d0b0e..95ccf07 100644
--- a/funasr/models/encoder/mfcca_encoder.py
+++ b/funasr/models/encoder/mfcca_encoder.py
@@ -38,13 +38,12 @@
 import pdb
 import math
 
+
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
-
     Args:
         channels (int): The number of channels of conv layers.
         kernel_size (int): Kernerl size of conv layers.
-
     """
 
     def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
@@ -83,13 +82,10 @@
 
     def forward(self, x):
         """Compute convolution module.
-
         Args:
             x (torch.Tensor): Input tensor (#batch, time, channels).
-
         Returns:
             torch.Tensor: Output tensor (#batch, time, channels).
-
         """
         # exchange the temporal dimension and the feature dimension
         x = x.transpose(1, 2)
@@ -107,10 +103,8 @@
         return x.transpose(1, 2)
 
 
-
 class MFCCAEncoder(AbsEncoder):
     """Conformer encoder module.
-
     Args:
         input_size (int): Input dimension.
         output_size (int): Dimention of attention.
@@ -140,33 +134,32 @@
         zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
         cnn_module_kernel (int): Kernerl size of convolution module.
         padding_idx (int): Padding idx for input_layer=embed.
-
     """
 
     def __init__(
-        self,
-        input_size: int,
-        output_size: int = 256,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0.0,
-        input_layer: str = "conv2d",
-        normalize_before: bool = True,
-        concat_after: bool = False,
-        positionwise_layer_type: str = "linear",
-        positionwise_conv_kernel_size: int = 3,
-        macaron_style: bool = False,
-        rel_pos_type: str = "legacy",
-        pos_enc_layer_type: str = "rel_pos",
-        selfattention_layer_type: str = "rel_selfattn",
-        activation_type: str = "swish",
-        use_cnn_module: bool = True,
-        zero_triu: bool = False,
-        cnn_module_kernel: int = 31,
-        padding_idx: int = -1,
+            self,
+            input_size: int,
+            output_size: int = 256,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            attention_dropout_rate: float = 0.0,
+            input_layer: str = "conv2d",
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            positionwise_layer_type: str = "linear",
+            positionwise_conv_kernel_size: int = 3,
+            macaron_style: bool = False,
+            rel_pos_type: str = "legacy",
+            pos_enc_layer_type: str = "rel_pos",
+            selfattention_layer_type: str = "rel_selfattn",
+            activation_type: str = "swish",
+            use_cnn_module: bool = True,
+            zero_triu: bool = False,
+            cnn_module_kernel: int = 31,
+            padding_idx: int = -1,
     ):
         assert check_argument_types()
         super().__init__()
@@ -199,7 +192,7 @@
             )
         else:
             raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
-        
+
         if input_layer == "linear":
             self.embed = torch.nn.Sequential(
                 torch.nn.Linear(input_size, output_size),
@@ -283,7 +276,7 @@
             assert pos_enc_layer_type == "legacy_rel_pos"
             encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
             encoder_selfattn_layer_args = (
-               attention_heads,
+                attention_heads,
                 output_size,
                 attention_dropout_rate,
             )
@@ -326,42 +319,39 @@
         )
         if self.normalize_before:
             self.after_norm = LayerNorm(output_size)
-        self.conv1 = torch.nn.Conv2d(8, 16, [5,7], stride=[1,1], padding=(2,3))
+        self.conv1 = torch.nn.Conv2d(8, 16, [5, 7], stride=[1, 1], padding=(2, 3))
 
-        self.conv2 = torch.nn.Conv2d(16, 32, [5,7], stride=[1,1], padding=(2,3))
+        self.conv2 = torch.nn.Conv2d(16, 32, [5, 7], stride=[1, 1], padding=(2, 3))
 
-        self.conv3 = torch.nn.Conv2d(32, 16, [5,7], stride=[1,1], padding=(2,3))
+        self.conv3 = torch.nn.Conv2d(32, 16, [5, 7], stride=[1, 1], padding=(2, 3))
 
-        self.conv4 = torch.nn.Conv2d(16, 1, [5,7], stride=[1,1], padding=(2,3))
+        self.conv4 = torch.nn.Conv2d(16, 1, [5, 7], stride=[1, 1], padding=(2, 3))
 
     def output_size(self) -> int:
         return self._output_size
 
     def forward(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        channel_size: torch.Tensor,
-        prev_states: torch.Tensor = None,
+            self,
+            xs_pad: torch.Tensor,
+            ilens: torch.Tensor,
+            channel_size: torch.Tensor,
+            prev_states: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """Calculate forward propagation.
-
         Args:
             xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
             ilens (torch.Tensor): Input length (#batch).
             prev_states (torch.Tensor): Not to be used now.
-
         Returns:
             torch.Tensor: Output tensor (#batch, L, output_size).
             torch.Tensor: Output length (#batch).
             torch.Tensor: Not to be used now.
-
         """
         masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
         if (
-            isinstance(self.embed, Conv2dSubsampling)
-            or isinstance(self.embed, Conv2dSubsampling6)
-            or isinstance(self.embed, Conv2dSubsampling8)
+                isinstance(self.embed, Conv2dSubsampling)
+                or isinstance(self.embed, Conv2dSubsampling6)
+                or isinstance(self.embed, Conv2dSubsampling8)
         ):
             short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
             if short_status:
@@ -380,48 +370,46 @@
 
         t_leng = xs_pad.size(1)
         d_dim = xs_pad.size(2)
-        xs_pad = xs_pad.reshape(-1,channel_size,t_leng,d_dim)
-        #pdb.set_trace()
-        if(channel_size<8):
-            repeat_num = math.ceil(8/channel_size)
-            xs_pad = xs_pad.repeat(1,repeat_num,1,1)[:,0:8,:,:]
+        xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
+        # pdb.set_trace()
+        if (channel_size < 8):
+            repeat_num = math.ceil(8 / channel_size)
+            xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
         xs_pad = self.conv1(xs_pad)
         xs_pad = self.conv2(xs_pad)
         xs_pad = self.conv3(xs_pad)
         xs_pad = self.conv4(xs_pad)
-        xs_pad = xs_pad.squeeze().reshape(-1,t_leng,d_dim)
+        xs_pad = xs_pad.squeeze().reshape(-1, t_leng, d_dim)
         mask_tmp = masks.size(1)
-        masks = masks.reshape(-1,channel_size,mask_tmp,t_leng)[:,0,:,:]
+        masks = masks.reshape(-1, channel_size, mask_tmp, t_leng)[:, 0, :, :]
 
         if self.normalize_before:
             xs_pad = self.after_norm(xs_pad)
 
         olens = masks.squeeze(1).sum(1)
         return xs_pad, olens, None
+
     def forward_hidden(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        prev_states: torch.Tensor = None,
+            self,
+            xs_pad: torch.Tensor,
+            ilens: torch.Tensor,
+            prev_states: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """Calculate forward propagation.
-
         Args:
             xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
             ilens (torch.Tensor): Input length (#batch).
             prev_states (torch.Tensor): Not to be used now.
-
         Returns:
             torch.Tensor: Output tensor (#batch, L, output_size).
             torch.Tensor: Output length (#batch).
             torch.Tensor: Not to be used now.
-
         """
         masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
         if (
-            isinstance(self.embed, Conv2dSubsampling)
-            or isinstance(self.embed, Conv2dSubsampling6)
-            or isinstance(self.embed, Conv2dSubsampling8)
+                isinstance(self.embed, Conv2dSubsampling)
+                or isinstance(self.embed, Conv2dSubsampling6)
+                or isinstance(self.embed, Conv2dSubsampling8)
         ):
             short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
             if short_status:
@@ -447,4 +435,4 @@
             self.hidden_feature = self.after_norm(hidden_feature)
 
         olens = masks.squeeze(1).sum(1)
-        return xs_pad, olens, None
+        return xs_pad, olens, None
\ No newline at end of file

--
Gitblit v1.9.1