From fa25b637b0d257186a8399eb1c530a91f4252702 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: Fri, 14 Apr 2023 15:44:50 +0800
Subject: [PATCH] remove some functions
---
funasr/models/e2e_transducer_unified.py | 2
funasr/models/encoder/conformer_encoder.py | 11 --
funasr/models/rnnt_predictor/abs_decoder.py | 0
funasr/models/rnnt_predictor/stateless_decoder.py | 2
funasr/models/e2e_transducer.py | 2
funasr/models/joint_network.py | 5
funasr/models/rnnt_predictor/__init__.py | 0
/dev/null | 170 ------------------------------------------
funasr/modules/e2e_asr_common.py | 2
funasr/modules/beam_search/beam_search_transducer.py | 2
funasr/models/rnnt_predictor/rnn_decoder.py | 2
funasr/modules/repeat.py | 3
funasr/tasks/asr_transducer.py | 6
13 files changed, 13 insertions(+), 194 deletions(-)
diff --git a/funasr/models/e2e_transducer.py b/funasr/models/e2e_transducer.py
index 8630aec..460a6d7 100644
--- a/funasr/models/e2e_transducer.py
+++ b/funasr/models/e2e_transducer.py
@@ -10,7 +10,7 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
from funasr.models.joint_network import JointNetwork
diff --git a/funasr/models/e2e_transducer_unified.py b/funasr/models/e2e_transducer_unified.py
index 124bc09..f79ba57 100644
--- a/funasr/models/e2e_transducer_unified.py
+++ b/funasr/models/e2e_transducer_unified.py
@@ -10,7 +10,7 @@
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
from funasr.models.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index c837cf5..b7b552c 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -30,7 +30,6 @@
StreamingRelPositionalEncoding,
)
from funasr.modules.layer_norm import LayerNorm
-from funasr.modules.normalization import get_normalization
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import get_activation
@@ -940,7 +939,6 @@
default_chunk_size: int = 16,
jitter_range: int = 4,
subsampling_factor: int = 1,
- **activation_parameters,
) -> None:
"""Construct an Encoder object."""
super().__init__()
@@ -961,7 +959,7 @@
)
activation = get_activation(
- activation_type, **activation_parameters
+ activation_type
)
pos_wise_args = (
@@ -991,9 +989,6 @@
simplified_att_score,
)
- norm_class, norm_args = get_normalization(
- norm_type,
- )
fn_modules = []
for _ in range(num_blocks):
@@ -1003,8 +998,6 @@
PositionwiseFeedForward(*pos_wise_args),
PositionwiseFeedForward(*pos_wise_args),
CausalConvolution(*conv_mod_args),
- norm_class=norm_class,
- norm_args=norm_args,
dropout_rate=dropout_rate,
)
fn_modules.append(module)
@@ -1012,8 +1005,6 @@
self.encoders = MultiBlocks(
[fn() for fn in fn_modules],
output_size,
- norm_class=norm_class,
- norm_args=norm_args,
)
self.output_size = output_size
diff --git a/funasr/models/joint_network.py b/funasr/models/joint_network.py
index 5cabdb4..ed827c4 100644
--- a/funasr/models/joint_network.py
+++ b/funasr/models/joint_network.py
@@ -2,7 +2,7 @@
import torch
-from funasr.modules.activation import get_activation
+from funasr.modules.nets_utils import get_activation
class JointNetwork(torch.nn.Module):
@@ -25,7 +25,6 @@
decoder_size: int,
joint_space_size: int = 256,
joint_activation_type: str = "tanh",
- **activation_parameters,
) -> None:
"""Construct a JointNetwork object."""
super().__init__()
@@ -36,7 +35,7 @@
self.lin_out = torch.nn.Linear(joint_space_size, output_size)
self.joint_activation = get_activation(
- joint_activation_type, **activation_parameters
+ joint_activation_type
)
def forward(
diff --git a/funasr/models/rnnt_decoder/__init__.py b/funasr/models/rnnt_predictor/__init__.py
similarity index 100%
rename from funasr/models/rnnt_decoder/__init__.py
rename to funasr/models/rnnt_predictor/__init__.py
diff --git a/funasr/models/rnnt_decoder/abs_decoder.py b/funasr/models/rnnt_predictor/abs_decoder.py
similarity index 100%
rename from funasr/models/rnnt_decoder/abs_decoder.py
rename to funasr/models/rnnt_predictor/abs_decoder.py
diff --git a/funasr/models/rnnt_decoder/rnn_decoder.py b/funasr/models/rnnt_predictor/rnn_decoder.py
similarity index 98%
rename from funasr/models/rnnt_decoder/rnn_decoder.py
rename to funasr/models/rnnt_predictor/rnn_decoder.py
index c4e7951..0df6fc7 100644
--- a/funasr/models/rnnt_decoder/rnn_decoder.py
+++ b/funasr/models/rnnt_predictor/rnn_decoder.py
@@ -6,7 +6,7 @@
from typeguard import check_argument_types
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug
class RNNDecoder(AbsDecoder):
diff --git a/funasr/models/rnnt_decoder/stateless_decoder.py b/funasr/models/rnnt_predictor/stateless_decoder.py
similarity index 98%
rename from funasr/models/rnnt_decoder/stateless_decoder.py
rename to funasr/models/rnnt_predictor/stateless_decoder.py
index a2e1fc1..70cd877 100644
--- a/funasr/models/rnnt_decoder/stateless_decoder.py
+++ b/funasr/models/rnnt_predictor/stateless_decoder.py
@@ -6,7 +6,7 @@
from typeguard import check_argument_types
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug
class StatelessDecoder(AbsDecoder):
diff --git a/funasr/modules/activation.py b/funasr/modules/activation.py
deleted file mode 100644
index 82cda12..0000000
--- a/funasr/modules/activation.py
+++ /dev/null
@@ -1,213 +0,0 @@
-"""Activation functions for Transducer."""
-
-import torch
-from packaging.version import parse as V
-
-
-def get_activation(
- activation_type: str,
- ftswish_threshold: float = -0.2,
- ftswish_mean_shift: float = 0.0,
- hardtanh_min_val: int = -1.0,
- hardtanh_max_val: int = 1.0,
- leakyrelu_neg_slope: float = 0.01,
- smish_alpha: float = 1.0,
- smish_beta: float = 1.0,
- softplus_beta: float = 1.0,
- softplus_threshold: int = 20,
- swish_beta: float = 1.0,
-) -> torch.nn.Module:
- """Return activation function.
-
- Args:
- activation_type: Activation function type.
- ftswish_threshold: Threshold value for FTSwish activation formulation.
- ftswish_mean_shift: Mean shifting value for FTSwish activation formulation.
- hardtanh_min_val: Minimum value of the linear region range for HardTanh.
- hardtanh_max_val: Maximum value of the linear region range for HardTanh.
- leakyrelu_neg_slope: Negative slope value for LeakyReLU activation formulation.
- smish_alpha: Alpha value for Smish activation fomulation.
- smish_beta: Beta value for Smish activation formulation.
- softplus_beta: Beta value for softplus activation formulation in Mish.
- softplus_threshold: Values above this revert to a linear function in Mish.
- swish_beta: Beta value for Swish variant formulation.
-
- Returns:
- : Activation function.
-
- """
- torch_version = V(torch.__version__)
-
- activations = {
- "ftswish": (
- FTSwish,
- {"threshold": ftswish_threshold, "mean_shift": ftswish_mean_shift},
- ),
- "hardtanh": (
- torch.nn.Hardtanh,
- {"min_val": hardtanh_min_val, "max_val": hardtanh_max_val},
- ),
- "leaky_relu": (torch.nn.LeakyReLU, {"negative_slope": leakyrelu_neg_slope}),
- "mish": (
- Mish,
- {
- "softplus_beta": softplus_beta,
- "softplus_threshold": softplus_threshold,
- "use_builtin": torch_version >= V("1.9"),
- },
- ),
- "relu": (torch.nn.ReLU, {}),
- "selu": (torch.nn.SELU, {}),
- "smish": (Smish, {"alpha": smish_alpha, "beta": smish_beta}),
- "swish": (
- Swish,
- {"beta": swish_beta, "use_builtin": torch_version >= V("1.8")},
- ),
- "tanh": (torch.nn.Tanh, {}),
- "identity": (torch.nn.Identity, {}),
- }
-
- act_func, act_args = activations[activation_type]
-
- return act_func(**act_args)
-
-
-class FTSwish(torch.nn.Module):
- """Flatten-T Swish activation definition.
-
- FTSwish(x) = x * sigmoid(x) + threshold
- where FTSwish(x) < 0 = threshold
-
- Reference: https://arxiv.org/abs/1812.06247
-
- Args:
- threshold: Threshold value for FTSwish activation formulation. (threshold < 0)
- mean_shift: Mean shifting value for FTSwish activation formulation.
- (applied only if != 0, disabled by default)
-
- """
-
- def __init__(self, threshold: float = -0.2, mean_shift: float = 0) -> None:
- super().__init__()
-
- assert threshold < 0, "FTSwish threshold parameter should be < 0."
-
- self.threshold = threshold
- self.mean_shift = mean_shift
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Forward computation."""
- x = (x * torch.sigmoid(x)) + self.threshold
- x = torch.where(x >= 0, x, torch.tensor([self.threshold], device=x.device))
-
- if self.mean_shift != 0:
- x.sub_(self.mean_shift)
-
- return x
-
-
-class Mish(torch.nn.Module):
- """Mish activation definition.
-
- Mish(x) = x * tanh(softplus(x))
-
- Reference: https://arxiv.org/abs/1908.08681.
-
- Args:
- softplus_beta: Beta value for softplus activation formulation.
- (Usually 0 > softplus_beta >= 2)
- softplus_threshold: Values above this revert to a linear function.
- (Usually 10 > softplus_threshold >= 20)
- use_builtin: Whether to use PyTorch activation function if available.
-
- """
-
- def __init__(
- self,
- softplus_beta: float = 1.0,
- softplus_threshold: int = 20,
- use_builtin: bool = False,
- ) -> None:
- super().__init__()
-
- if use_builtin:
- self.mish = torch.nn.Mish()
- else:
- self.tanh = torch.nn.Tanh()
- self.softplus = torch.nn.Softplus(
- beta=softplus_beta, threshold=softplus_threshold
- )
-
- self.mish = lambda x: x * self.tanh(self.softplus(x))
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Forward computation."""
- return self.mish(x)
-
-
-class Smish(torch.nn.Module):
- """Smish activation definition.
-
- Smish(x) = (alpha * x) * tanh(log(1 + sigmoid(beta * x)))
- where alpha > 0 and beta > 0
-
- Reference: https://www.mdpi.com/2079-9292/11/4/540/htm.
-
- Args:
- alpha: Alpha value for Smish activation fomulation.
- (Usually, alpha = 1. If alpha <= 0, set value to 1).
- beta: Beta value for Smish activation formulation.
- (Usually, beta = 1. If beta <= 0, set value to 1).
-
- """
-
- def __init__(self, alpha: float = 1.0, beta: float = 1.0) -> None:
- super().__init__()
-
- self.tanh = torch.nn.Tanh()
-
- self.alpha = alpha if alpha > 0 else 1
- self.beta = beta if beta > 0 else 1
-
- self.smish = lambda x: (self.alpha * x) * self.tanh(
- torch.log(1 + torch.sigmoid((self.beta * x)))
- )
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Forward computation."""
- return self.smish(x)
-
-
-class Swish(torch.nn.Module):
- """Swish activation definition.
-
- Swish(x) = (beta * x) * sigmoid(x)
- where beta = 1 defines standard Swish activation.
-
- References:
- https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
- E-swish variant: https://arxiv.org/abs/1801.07145.
-
- Args:
- beta: Beta parameter for E-Swish.
- (beta >= 1. If beta < 1, use standard Swish).
- use_builtin: Whether to use PyTorch function if available.
-
- """
-
- def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
- super().__init__()
-
- self.beta = beta
-
- if beta > 1:
- self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
- else:
- if use_builtin:
- self.swish = torch.nn.SiLU()
- else:
- self.swish = lambda x: x * torch.sigmoid(x)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Forward computation."""
- return self.swish(x)
diff --git a/funasr/modules/beam_search/beam_search_transducer.py b/funasr/modules/beam_search/beam_search_transducer.py
index eaf5627..49cce92 100644
--- a/funasr/modules/beam_search/beam_search_transducer.py
+++ b/funasr/modules/beam_search/beam_search_transducer.py
@@ -6,7 +6,7 @@
import numpy as np
import torch
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
from funasr.models.joint_network import JointNetwork
diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py
index 9b5039c..3746036 100644
--- a/funasr/modules/e2e_asr_common.py
+++ b/funasr/modules/e2e_asr_common.py
@@ -18,7 +18,7 @@
import torch
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
from funasr.models.joint_network import JointNetwork
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
diff --git a/funasr/modules/normalization.py b/funasr/modules/normalization.py
deleted file mode 100644
index ae35fd4..0000000
--- a/funasr/modules/normalization.py
+++ /dev/null
@@ -1,170 +0,0 @@
-"""Normalization modules for X-former blocks."""
-
-from typing import Dict, Optional, Tuple
-
-import torch
-
-
-def get_normalization(
- normalization_type: str,
- eps: Optional[float] = None,
- partial: Optional[float] = None,
-) -> Tuple[torch.nn.Module, Dict]:
- """Get normalization module and arguments given parameters.
-
- Args:
- normalization_type: Normalization module type.
- eps: Value added to the denominator.
- partial: Value defining the part of the input used for RMS stats (RMSNorm).
-
- Return:
- : Normalization module class
- : Normalization module arguments
-
- """
- norm = {
- "basic_norm": (
- BasicNorm,
- {"eps": eps if eps is not None else 0.25},
- ),
- "layer_norm": (torch.nn.LayerNorm, {"eps": eps if eps is not None else 1e-12}),
- "rms_norm": (
- RMSNorm,
- {
- "eps": eps if eps is not None else 1e-05,
- "partial": partial if partial is not None else -1.0,
- },
- ),
- "scale_norm": (
- ScaleNorm,
- {"eps": eps if eps is not None else 1e-05},
- ),
- }
-
- return norm[normalization_type]
-
-
-class BasicNorm(torch.nn.Module):
- """BasicNorm module definition.
-
- Reference: https://github.com/k2-fsa/icefall/pull/288
-
- Args:
- normalized_shape: Expected size.
- eps: Value added to the denominator for numerical stability.
-
- """
-
- def __init__(
- self,
- normalized_shape: int,
- eps: float = 0.25,
- ) -> None:
- """Construct a BasicNorm object."""
- super().__init__()
-
- self.eps = torch.nn.Parameter(torch.tensor(eps).log().detach())
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Compute basic normalization.
-
- Args:
- x: Input sequences. (B, T, D_hidden)
-
- Returns:
- : Output sequences. (B, T, D_hidden)
-
- """
- scales = (torch.mean(x.pow(2), dim=-1, keepdim=True) + self.eps.exp()) ** -0.5
-
- return x * scales
-
-
-class RMSNorm(torch.nn.Module):
- """RMSNorm module definition.
-
- Reference: https://arxiv.org/pdf/1910.07467.pdf
-
- Args:
- normalized_shape: Expected size.
- eps: Value added to the denominator for numerical stability.
- partial: Value defining the part of the input used for RMS stats.
-
- """
-
- def __init__(
- self,
- normalized_shape: int,
- eps: float = 1e-5,
- partial: float = 0.0,
- ) -> None:
- """Construct a RMSNorm object."""
- super().__init__()
-
- self.normalized_shape = normalized_shape
-
- self.partial = True if 0 < partial < 1 else False
- self.p = partial
- self.eps = eps
-
- self.scale = torch.nn.Parameter(torch.ones(normalized_shape))
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Compute RMS normalization.
-
- Args:
- x: Input sequences. (B, T, D_hidden)
-
- Returns:
- x: Output sequences. (B, T, D_hidden)
-
- """
- if self.partial:
- partial_size = int(self.normalized_shape * self.p)
- partial_x, _ = torch.split(
- x, [partial_size, self.normalized_shape - partial_size], dim=-1
- )
-
- norm_x = partial_x.norm(2, dim=-1, keepdim=True)
- d_x = partial_size
- else:
- norm_x = x.norm(2, dim=-1, keepdim=True)
- d_x = self.normalized_shape
-
- rms_x = norm_x * d_x ** (-1.0 / 2)
- x = self.scale * (x / (rms_x + self.eps))
-
- return x
-
-
-class ScaleNorm(torch.nn.Module):
- """ScaleNorm module definition.
-
- Reference: https://arxiv.org/pdf/1910.05895.pdf
-
- Args:
- normalized_shape: Expected size.
- eps: Value added to the denominator for numerical stability.
-
- """
-
- def __init__(self, normalized_shape: int, eps: float = 1e-5) -> None:
- """Construct a ScaleNorm object."""
- super().__init__()
-
- self.eps = eps
- self.scale = torch.nn.Parameter(torch.tensor(normalized_shape**0.5))
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Compute scale normalization.
-
- Args:
- x: Input sequences. (B, T, D_hidden)
-
- Returns:
- : Output sequences. (B, T, D_hidden)
-
- """
- norm = self.scale / torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
-
- return x * norm
diff --git a/funasr/modules/repeat.py b/funasr/modules/repeat.py
index 7241dd9..2b2dac8 100644
--- a/funasr/modules/repeat.py
+++ b/funasr/modules/repeat.py
@@ -49,13 +49,12 @@
block_list: List[torch.nn.Module],
output_size: int,
norm_class: torch.nn.Module = torch.nn.LayerNorm,
- norm_args: Optional[Dict] = None,
) -> None:
"""Construct a MultiBlocks object."""
super().__init__()
self.blocks = torch.nn.ModuleList(block_list)
- self.norm_blocks = norm_class(output_size, **norm_args)
+ self.norm_blocks = norm_class(output_size)
self.num_blocks = len(block_list)
diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
index bb1f996..99b3d0c 100644
--- a/funasr/tasks/asr_transducer.py
+++ b/funasr/tasks/asr_transducer.py
@@ -21,9 +21,9 @@
LightweightConvolutionTransformerDecoder,
TransformerDecoder,
)
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
-from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
-from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
+from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
+from funasr.models.rnnt_predictor.rnn_decoder import RNNDecoder
+from funasr.models.rnnt_predictor.stateless_decoder import StatelessDecoder
from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder
from funasr.models.e2e_transducer import TransducerModel
from funasr.models.e2e_transducer_unified import UnifiedTransducerModel
--
Gitblit v1.9.1