From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
---
funasr/models/mfcca/mfcca_encoder.py | 108 +++++++++++++++++++++++++-----------------------------
 1 file changed, 50 insertions(+), 58 deletions(-)
diff --git a/funasr/models/mfcca/mfcca_encoder.py b/funasr/models/mfcca/mfcca_encoder.py
index 92dd6e7..a0bb58e 100644
--- a/funasr/models/mfcca/mfcca_encoder.py
+++ b/funasr/models/mfcca/mfcca_encoder.py
@@ -26,15 +26,14 @@
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
-from funasr.models.transformer.repeat import repeat
-from funasr.models.transformer.subsampling import Conv2dSubsampling
-from funasr.models.transformer.subsampling import Conv2dSubsampling2
-from funasr.models.transformer.subsampling import Conv2dSubsampling6
-from funasr.models.transformer.subsampling import Conv2dSubsampling8
-from funasr.models.transformer.subsampling import TooShortUttError
-from funasr.models.transformer.subsampling import check_short_utt
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.encoder.abs_encoder import AbsEncoder
-import pdb
import math
@@ -136,29 +135,29 @@
"""
def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: str = "conv2d",
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 3,
- macaron_style: bool = False,
- rel_pos_type: str = "legacy",
- pos_enc_layer_type: str = "rel_pos",
- selfattention_layer_type: str = "rel_selfattn",
- activation_type: str = "swish",
- use_cnn_module: bool = True,
- zero_triu: bool = False,
- cnn_module_kernel: int = 31,
- padding_idx: int = -1,
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 3,
+ macaron_style: bool = False,
+ rel_pos_type: str = "legacy",
+ pos_enc_layer_type: str = "rel_pos",
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ zero_triu: bool = False,
+ cnn_module_kernel: int = 31,
+ padding_idx: int = -1,
):
super().__init__()
self._output_size = output_size
@@ -185,9 +184,7 @@
elif pos_enc_layer_type == "legacy_rel_pos":
assert selfattention_layer_type == "legacy_rel_selfattn"
pos_enc_class = LegacyRelPositionalEncoding
- logging.warning(
- "Using legacy_rel_pos and it will be deprecated in the future."
- )
+ logging.warning("Using legacy_rel_pos and it will be deprecated in the future.")
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
@@ -230,9 +227,7 @@
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer is None:
- self.embed = torch.nn.Sequential(
- pos_enc_class(output_size, positional_dropout_rate)
- )
+ self.embed = torch.nn.Sequential(pos_enc_class(output_size, positional_dropout_rate))
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
@@ -278,9 +273,7 @@
output_size,
attention_dropout_rate,
)
- logging.warning(
- "Using legacy_rel_selfattn and it will be deprecated in the future."
- )
+ logging.warning("Using legacy_rel_selfattn and it will be deprecated in the future.")
elif selfattention_layer_type == "rel_selfattn":
assert pos_enc_layer_type == "rel_pos"
encoder_selfattn_layer = RelPositionMultiHeadedAttention
@@ -329,11 +322,11 @@
return self._output_size
def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- channel_size: torch.Tensor,
- prev_states: torch.Tensor = None,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ channel_size: torch.Tensor,
+ prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
Args:
@@ -347,9 +340,9 @@
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
@@ -369,8 +362,7 @@
t_leng = xs_pad.size(1)
d_dim = xs_pad.size(2)
xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
- # pdb.set_trace()
- if (channel_size < 8):
+ if channel_size < 8:
repeat_num = math.ceil(8 / channel_size)
xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
xs_pad = self.conv1(xs_pad)
@@ -388,10 +380,10 @@
return xs_pad, olens, None
def forward_hidden(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
Args:
@@ -405,9 +397,9 @@
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
@@ -433,4 +425,4 @@
self.hidden_feature = self.after_norm(hidden_feature)
olens = masks.squeeze(1).sum(1)
- return xs_pad, olens, None
\ No newline at end of file
+ return xs_pad, olens, None
--
Gitblit v1.9.1