From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' on line 514 to avoid a naming conflict with line 365

---
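Note for reviewers: most hunks below only reflow indentation; the hunk in
whisper_encode also reflows the conv output-length arithmetic. As a minimal
sketch of that arithmetic (the concrete values kernel_size=3, stride=2,
padding=1 are assumptions matching Whisper's conv2 defaults, not values
taken from this diff):

    def conv1d_out_len(ilens: int, kernel_size: int = 3, stride: int = 2, padding: int = 1) -> int:
        # Standard 1D convolution output length: 1 + (L - K + 2P) // S.
        return 1 + (ilens - kernel_size + 2 * padding) // stride

    assert conv1d_out_len(3000) == 1500  # e.g. 3000 mel frames -> 1500 encoder positions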
 funasr/models/whisper_lid/encoder.py |   42 ++++++++++++++++++------------------------
 1 file changed, 18 insertions(+), 24 deletions(-)

diff --git a/funasr/models/whisper_lid/encoder.py b/funasr/models/whisper_lid/encoder.py
index 7eeb643..b6fb28b 100644
--- a/funasr/models/whisper_lid/encoder.py
+++ b/funasr/models/whisper_lid/encoder.py
@@ -22,13 +22,13 @@
     """
 
     def __init__(
-            self,
-            dropout_rate: float = 0.0,
-            whisper_model: str = "small",
-            download_dir: str = None,
-            use_specaug: bool = False,
-            use_padmask: bool = False,
-            specaug_conf: Union[dict, None] = None,
+        self,
+        dropout_rate: float = 0.0,
+        whisper_model: str = "small",
+        download_dir: str = None,
+        use_specaug: bool = False,
+        use_padmask: bool = False,
+        specaug_conf: Union[dict, None] = None,
     ):
         super().__init__()
 
@@ -36,9 +36,7 @@
         self.dropout = torch.nn.Dropout(dropout_rate)
 
         assert whisper_model in whisper.available_models()
-        _model = whisper.load_model(
-            whisper_model, download_root=download_dir, device="cpu"
-        )
+        _model = whisper.load_model(whisper_model, download_root=download_dir, device="cpu")
         self.encoders = copy.deepcopy(_model.encoder)
         self.encoders.train()
 
@@ -51,9 +49,9 @@
         self.use_padmask = use_padmask
 
     def whisper_encode(
-            self,
-            input: torch.Tensor,
-            ilens: torch.Tensor = None,
+        self,
+        input: torch.Tensor,
+        ilens: torch.Tensor = None,
     ) -> torch.Tensor:
         x = F.gelu(self.encoders.conv1(input))
         x = F.gelu(self.encoders.conv2(x))
@@ -69,13 +67,9 @@
 
         if ilens is not None:
             olens = (
-                    1
-                    + (
-                            ilens
-                            - self.encoders.conv2.kernel_size[0]
-                            + 2 * self.encoders.conv2.padding[0]
-                    )
-                    // self.encoders.conv2.stride[0]
+                1
+                + (ilens - self.encoders.conv2.kernel_size[0] + 2 * self.encoders.conv2.padding[0])
+                // self.encoders.conv2.stride[0]
             )
             olens = torch.clamp(olens, max=max_pos)
         else:
@@ -102,10 +96,10 @@
         return self.encoders.conv2.weight.shape[0]
 
     def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
-            prev_states: torch.Tensor = None,
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        prev_states: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         feats, feats_lens = xs_pad, ilens
 

--
Gitblit v1.9.1