From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/models/whisper_lid/encoder.py | 42 ++++++++++++++++++------------------------
1 files changed, 18 insertions(+), 24 deletions(-)
diff --git a/funasr/models/whisper_lid/encoder.py b/funasr/models/whisper_lid/encoder.py
index 7eeb643..b6fb28b 100644
--- a/funasr/models/whisper_lid/encoder.py
+++ b/funasr/models/whisper_lid/encoder.py
@@ -22,13 +22,13 @@
"""
def __init__(
- self,
- dropout_rate: float = 0.0,
- whisper_model: str = "small",
- download_dir: str = None,
- use_specaug: bool = False,
- use_padmask: bool = False,
- specaug_conf: Union[dict, None] = None,
+ self,
+ dropout_rate: float = 0.0,
+ whisper_model: str = "small",
+ download_dir: str = None,
+ use_specaug: bool = False,
+ use_padmask: bool = False,
+ specaug_conf: Union[dict, None] = None,
):
super().__init__()
@@ -36,9 +36,7 @@
self.dropout = torch.nn.Dropout(dropout_rate)
assert whisper_model in whisper.available_models()
- _model = whisper.load_model(
- whisper_model, download_root=download_dir, device="cpu"
- )
+ _model = whisper.load_model(whisper_model, download_root=download_dir, device="cpu")
self.encoders = copy.deepcopy(_model.encoder)
self.encoders.train()
@@ -51,9 +49,9 @@
self.use_padmask = use_padmask
def whisper_encode(
- self,
- input: torch.Tensor,
- ilens: torch.Tensor = None,
+ self,
+ input: torch.Tensor,
+ ilens: torch.Tensor = None,
) -> torch.Tensor:
x = F.gelu(self.encoders.conv1(input))
x = F.gelu(self.encoders.conv2(x))
@@ -69,13 +67,9 @@
if ilens is not None:
olens = (
- 1
- + (
- ilens
- - self.encoders.conv2.kernel_size[0]
- + 2 * self.encoders.conv2.padding[0]
- )
- // self.encoders.conv2.stride[0]
+ 1
+ + (ilens - self.encoders.conv2.kernel_size[0] + 2 * self.encoders.conv2.padding[0])
+ // self.encoders.conv2.stride[0]
)
olens = torch.clamp(olens, max=max_pos)
else:
@@ -102,10 +96,10 @@
return self.encoders.conv2.weight.shape[0]
def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
feats, feats_lens = xs_pad, ilens
--
Gitblit v1.9.1