From 1cdb3cc28d4d89a576cc06e5cd8eb80da1f3a3aa Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 26 Apr 2024 11:27:39 +0800
Subject: [PATCH] Dev gzf exp (#1665)
---
funasr/models/transformer/encoder.py | 81 +++++++++++++++++++---------------------
 1 file changed, 39 insertions(+), 42 deletions(-)
diff --git a/funasr/models/transformer/encoder.py b/funasr/models/transformer/encoder.py
index 1f14867..a6a85ae 100644
--- a/funasr/models/transformer/encoder.py
+++ b/funasr/models/transformer/encoder.py
@@ -30,6 +30,7 @@
from funasr.register import tables
+
class EncoderLayer(nn.Module):
"""Encoder layer module.
@@ -53,14 +54,14 @@
"""
def __init__(
- self,
- size,
- self_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- stochastic_depth_rate=0.0,
+ self,
+ size,
+ self_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
@@ -118,9 +119,7 @@
x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
- x = residual + stoch_layer_coeff * self.dropout(
- self.self_attn(x_q, x, x, mask)
- )
+ x = residual + stoch_layer_coeff * self.dropout(self.self_attn(x_q, x, x, mask))
if not self.normalize_before:
x = self.norm1(x)
@@ -135,6 +134,7 @@
x = torch.cat([cache, x], dim=1)
return x, mask
+
@tables.register("encoder_classes", "TransformerEncoder")
class TransformerEncoder(nn.Module):
@@ -163,24 +163,24 @@
"""
def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: Optional[str] = "conv2d",
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: Optional[str] = "conv2d",
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 1,
+ padding_idx: int = -1,
+ interctc_layer_idx: List[int] = [],
+ interctc_use_conditioning: bool = False,
):
super().__init__()
self._output_size = output_size
@@ -243,9 +243,7 @@
num_blocks,
lambda lnum: EncoderLayer(
output_size,
- MultiHeadedAttention(
- attention_heads, output_size, attention_dropout_rate
- ),
+ MultiHeadedAttention(attention_heads, output_size, attention_dropout_rate),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
@@ -265,11 +263,11 @@
return self._output_size
def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
@@ -285,10 +283,10 @@
if self.embed is None:
xs_pad = xs_pad
elif (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling2)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
@@ -329,4 +327,3 @@
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
-
--
Gitblit v1.9.1