From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
---
funasr/models/mfcca/mfcca_encoder.py | 108 +++++++++++++++++++++++++-----------------------------
 1 file changed, 50 insertions(+), 58 deletions(-)
diff --git a/funasr/models/mfcca/mfcca_encoder.py b/funasr/models/mfcca/mfcca_encoder.py
index 92dd6e7..a0bb58e 100644
--- a/funasr/models/mfcca/mfcca_encoder.py
+++ b/funasr/models/mfcca/mfcca_encoder.py
@@ -26,15 +26,14 @@
from funasr.models.transformer.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
-from funasr.models.transformer.repeat import repeat
-from funasr.models.transformer.subsampling import Conv2dSubsampling
-from funasr.models.transformer.subsampling import Conv2dSubsampling2
-from funasr.models.transformer.subsampling import Conv2dSubsampling6
-from funasr.models.transformer.subsampling import Conv2dSubsampling8
-from funasr.models.transformer.subsampling import TooShortUttError
-from funasr.models.transformer.subsampling import check_short_utt
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.encoder.abs_encoder import AbsEncoder
-import pdb
import math
@@ -136,29 +135,29 @@
"""
def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: str = "conv2d",
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 3,
- macaron_style: bool = False,
- rel_pos_type: str = "legacy",
- pos_enc_layer_type: str = "rel_pos",
- selfattention_layer_type: str = "rel_selfattn",
- activation_type: str = "swish",
- use_cnn_module: bool = True,
- zero_triu: bool = False,
- cnn_module_kernel: int = 31,
- padding_idx: int = -1,
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 3,
+ macaron_style: bool = False,
+ rel_pos_type: str = "legacy",
+ pos_enc_layer_type: str = "rel_pos",
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ zero_triu: bool = False,
+ cnn_module_kernel: int = 31,
+ padding_idx: int = -1,
):
super().__init__()
self._output_size = output_size
@@ -185,9 +184,7 @@
elif pos_enc_layer_type == "legacy_rel_pos":
assert selfattention_layer_type == "legacy_rel_selfattn"
pos_enc_class = LegacyRelPositionalEncoding
- logging.warning(
- "Using legacy_rel_pos and it will be deprecated in the future."
- )
+ logging.warning("Using legacy_rel_pos and it will be deprecated in the future.")
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
@@ -230,9 +227,7 @@
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer is None:
- self.embed = torch.nn.Sequential(
- pos_enc_class(output_size, positional_dropout_rate)
- )
+ self.embed = torch.nn.Sequential(pos_enc_class(output_size, positional_dropout_rate))
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
@@ -278,9 +273,7 @@
output_size,
attention_dropout_rate,
)
- logging.warning(
- "Using legacy_rel_selfattn and it will be deprecated in the future."
- )
+ logging.warning("Using legacy_rel_selfattn and it will be deprecated in the future.")
elif selfattention_layer_type == "rel_selfattn":
assert pos_enc_layer_type == "rel_pos"
encoder_selfattn_layer = RelPositionMultiHeadedAttention
@@ -329,11 +322,11 @@
return self._output_size
def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- channel_size: torch.Tensor,
- prev_states: torch.Tensor = None,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ channel_size: torch.Tensor,
+ prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
Args:
@@ -347,9 +340,9 @@
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
@@ -369,8 +362,7 @@
t_leng = xs_pad.size(1)
d_dim = xs_pad.size(2)
xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
- # pdb.set_trace()
- if (channel_size < 8):
+ if channel_size < 8:
repeat_num = math.ceil(8 / channel_size)
xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
xs_pad = self.conv1(xs_pad)
@@ -388,10 +380,10 @@
return xs_pad, olens, None
def forward_hidden(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
Args:
@@ -405,9 +397,9 @@
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
@@ -433,4 +425,4 @@
self.hidden_feature = self.after_norm(hidden_feature)
olens = masks.squeeze(1).sum(1)
- return xs_pad, olens, None
\ No newline at end of file
+ return xs_pad, olens, None
--
Gitblit v1.9.1