From fa25b637b0d257186a8399eb1c530a91f4252702 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: Fri, 14 Apr 2023 15:44:50 +0800
Subject: [PATCH] remove some functions

---
 funasr/models/e2e_transducer_unified.py              |    2 
 funasr/models/encoder/conformer_encoder.py           |   11 --
 funasr/models/rnnt_predictor/abs_decoder.py          |    0 
 funasr/models/rnnt_predictor/stateless_decoder.py    |    2 
 funasr/models/e2e_transducer.py                      |    2 
 funasr/models/joint_network.py                       |    5 
 funasr/models/rnnt_predictor/__init__.py             |    0 
 funasr/modules/activation.py                         |  213 -----------------------------------------------------
 funasr/modules/normalization.py                      |  170 ------------------------------------------
 funasr/modules/e2e_asr_common.py                     |    2 
 funasr/modules/beam_search/beam_search_transducer.py |    2 
 funasr/models/rnnt_predictor/rnn_decoder.py          |    2 
 funasr/modules/repeat.py                             |    3 
 funasr/tasks/asr_transducer.py                       |    6 
 14 files changed, 13 insertions(+), 407 deletions(-)

diff --git a/funasr/models/e2e_transducer.py b/funasr/models/e2e_transducer.py
index 8630aec..460a6d7 100644
--- a/funasr/models/e2e_transducer.py
+++ b/funasr/models/e2e_transducer.py
@@ -10,7 +10,7 @@
 
 from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
 from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
 from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
 from funasr.models.joint_network import JointNetwork
diff --git a/funasr/models/e2e_transducer_unified.py b/funasr/models/e2e_transducer_unified.py
index 124bc09..f79ba57 100644
--- a/funasr/models/e2e_transducer_unified.py
+++ b/funasr/models/e2e_transducer_unified.py
@@ -10,7 +10,7 @@
 
 from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
 from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
 from funasr.models.joint_network import JointNetwork
 from funasr.modules.nets_utils import get_transducer_task_io
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index c837cf5..b7b552c 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -30,7 +30,6 @@
     StreamingRelPositionalEncoding,
 )
 from funasr.modules.layer_norm import LayerNorm
-from funasr.modules.normalization import get_normalization
 from funasr.modules.multi_layer_conv import Conv1dLinear
 from funasr.modules.multi_layer_conv import MultiLayeredConv1d
 from funasr.modules.nets_utils import get_activation
@@ -940,7 +939,6 @@
         default_chunk_size: int = 16,
         jitter_range: int = 4,
         subsampling_factor: int = 1,
-        **activation_parameters,
     ) -> None:
         """Construct an Encoder object."""
         super().__init__()
@@ -961,7 +959,7 @@
         )
 
         activation = get_activation(
-            activation_type, **activation_parameters
+            activation_type
        )        
 
         pos_wise_args = (
@@ -991,9 +989,6 @@
             simplified_att_score,
         )
 
-        norm_class, norm_args = get_normalization(
-            norm_type,
-        )
 
         fn_modules = []
         for _ in range(num_blocks):
@@ -1003,8 +998,6 @@
                 PositionwiseFeedForward(*pos_wise_args),
                 PositionwiseFeedForward(*pos_wise_args),
                 CausalConvolution(*conv_mod_args),
-                norm_class=norm_class,
-                norm_args=norm_args,
                 dropout_rate=dropout_rate,
             )
             fn_modules.append(module)        
@@ -1012,8 +1005,6 @@
         self.encoders = MultiBlocks(
             [fn() for fn in fn_modules],
             output_size,
-            norm_class=norm_class,
-            norm_args=norm_args,
         )
 
         self.output_size = output_size
diff --git a/funasr/models/joint_network.py b/funasr/models/joint_network.py
index 5cabdb4..ed827c4 100644
--- a/funasr/models/joint_network.py
+++ b/funasr/models/joint_network.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from funasr.modules.activation import get_activation
+from funasr.modules.nets_utils import get_activation
 
 
 class JointNetwork(torch.nn.Module):
@@ -25,7 +25,6 @@
         decoder_size: int,
         joint_space_size: int = 256,
         joint_activation_type: str = "tanh",
-        **activation_parameters,
     ) -> None:
         """Construct a JointNetwork object."""
         super().__init__()
@@ -36,7 +35,7 @@
         self.lin_out = torch.nn.Linear(joint_space_size, output_size)
 
         self.joint_activation = get_activation(
-            joint_activation_type, **activation_parameters
+            joint_activation_type
         )
 
     def forward(
diff --git a/funasr/models/rnnt_decoder/__init__.py b/funasr/models/rnnt_predictor/__init__.py
similarity index 100%
rename from funasr/models/rnnt_decoder/__init__.py
rename to funasr/models/rnnt_predictor/__init__.py
diff --git a/funasr/models/rnnt_decoder/abs_decoder.py b/funasr/models/rnnt_predictor/abs_decoder.py
similarity index 100%
rename from funasr/models/rnnt_decoder/abs_decoder.py
rename to funasr/models/rnnt_predictor/abs_decoder.py
diff --git a/funasr/models/rnnt_decoder/rnn_decoder.py b/funasr/models/rnnt_predictor/rnn_decoder.py
similarity index 98%
rename from funasr/models/rnnt_decoder/rnn_decoder.py
rename to funasr/models/rnnt_predictor/rnn_decoder.py
index c4e7951..0df6fc7 100644
--- a/funasr/models/rnnt_decoder/rnn_decoder.py
+++ b/funasr/models/rnnt_predictor/rnn_decoder.py
@@ -6,7 +6,7 @@
 from typeguard import check_argument_types
 
 from funasr.modules.beam_search.beam_search_transducer import Hypothesis
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
 from funasr.models.specaug.specaug import SpecAug
 
 class RNNDecoder(AbsDecoder):
diff --git a/funasr/models/rnnt_decoder/stateless_decoder.py b/funasr/models/rnnt_predictor/stateless_decoder.py
similarity index 98%
rename from funasr/models/rnnt_decoder/stateless_decoder.py
rename to funasr/models/rnnt_predictor/stateless_decoder.py
index a2e1fc1..70cd877 100644
--- a/funasr/models/rnnt_decoder/stateless_decoder.py
+++ b/funasr/models/rnnt_predictor/stateless_decoder.py
@@ -6,7 +6,7 @@
 from typeguard import check_argument_types
 
 from funasr.modules.beam_search.beam_search_transducer import Hypothesis
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
 from funasr.models.specaug.specaug import SpecAug
 
 class StatelessDecoder(AbsDecoder):
diff --git a/funasr/modules/activation.py b/funasr/modules/activation.py
deleted file mode 100644
index 82cda12..0000000
--- a/funasr/modules/activation.py
+++ /dev/null
@@ -1,213 +0,0 @@
-"""Activation functions for Transducer."""
-
-import torch
-from packaging.version import parse as V
-
-
-def get_activation(
-    activation_type: str,
-    ftswish_threshold: float = -0.2,
-    ftswish_mean_shift: float = 0.0,
-    hardtanh_min_val: int = -1.0,
-    hardtanh_max_val: int = 1.0,
-    leakyrelu_neg_slope: float = 0.01,
-    smish_alpha: float = 1.0,
-    smish_beta: float = 1.0,
-    softplus_beta: float = 1.0,
-    softplus_threshold: int = 20,
-    swish_beta: float = 1.0,
-) -> torch.nn.Module:
-    """Return activation function.
-
-    Args:
-        activation_type: Activation function type.
-        ftswish_threshold: Threshold value for FTSwish activation formulation.
-        ftswish_mean_shift: Mean shifting value for FTSwish activation formulation.
-        hardtanh_min_val: Minimum value of the linear region range for HardTanh.
-        hardtanh_max_val: Maximum value of the linear region range for HardTanh.
-        leakyrelu_neg_slope: Negative slope value for LeakyReLU activation formulation.
-        smish_alpha: Alpha value for Smish activation fomulation.
-        smish_beta: Beta value for Smish activation formulation.
-        softplus_beta: Beta value for softplus activation formulation in Mish.
-        softplus_threshold: Values above this revert to a linear function in Mish.
-        swish_beta: Beta value for Swish variant formulation.
-
-    Returns:
-        : Activation function.
-
-    """
-    torch_version = V(torch.__version__)
-
-    activations = {
-        "ftswish": (
-            FTSwish,
-            {"threshold": ftswish_threshold, "mean_shift": ftswish_mean_shift},
-        ),
-        "hardtanh": (
-            torch.nn.Hardtanh,
-            {"min_val": hardtanh_min_val, "max_val": hardtanh_max_val},
-        ),
-        "leaky_relu": (torch.nn.LeakyReLU, {"negative_slope": leakyrelu_neg_slope}),
-        "mish": (
-            Mish,
-            {
-                "softplus_beta": softplus_beta,
-                "softplus_threshold": softplus_threshold,
-                "use_builtin": torch_version >= V("1.9"),
-            },
-        ),
-        "relu": (torch.nn.ReLU, {}),
-        "selu": (torch.nn.SELU, {}),
-        "smish": (Smish, {"alpha": smish_alpha, "beta": smish_beta}),
-        "swish": (
-            Swish,
-            {"beta": swish_beta, "use_builtin": torch_version >= V("1.8")},
-        ),
-        "tanh": (torch.nn.Tanh, {}),
-        "identity": (torch.nn.Identity, {}),
-    }
-
-    act_func, act_args = activations[activation_type]
-
-    return act_func(**act_args)
-
-
-class FTSwish(torch.nn.Module):
-    """Flatten-T Swish activation definition.
-
-    FTSwish(x) = x * sigmoid(x) + threshold
-                  where FTSwish(x) < 0 = threshold
-
-    Reference: https://arxiv.org/abs/1812.06247
-
-    Args:
-        threshold: Threshold value for FTSwish activation formulation. (threshold < 0)
-        mean_shift: Mean shifting value for FTSwish activation formulation.
-                       (applied only if != 0, disabled by default)
-
-    """
-
-    def __init__(self, threshold: float = -0.2, mean_shift: float = 0) -> None:
-        super().__init__()
-
-        assert threshold < 0, "FTSwish threshold parameter should be < 0."
-
-        self.threshold = threshold
-        self.mean_shift = mean_shift
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward computation."""
-        x = (x * torch.sigmoid(x)) + self.threshold
-        x = torch.where(x >= 0, x, torch.tensor([self.threshold], device=x.device))
-
-        if self.mean_shift != 0:
-            x.sub_(self.mean_shift)
-
-        return x
-
-
-class Mish(torch.nn.Module):
-    """Mish activation definition.
-
-    Mish(x) = x * tanh(softplus(x))
-
-    Reference: https://arxiv.org/abs/1908.08681.
-
-    Args:
-        softplus_beta: Beta value for softplus activation formulation.
-                         (Usually 0 > softplus_beta >= 2)
-        softplus_threshold: Values above this revert to a linear function.
-                         (Usually 10 > softplus_threshold >= 20)
-        use_builtin: Whether to use PyTorch activation function if available.
-
-    """
-
-    def __init__(
-        self,
-        softplus_beta: float = 1.0,
-        softplus_threshold: int = 20,
-        use_builtin: bool = False,
-    ) -> None:
-        super().__init__()
-
-        if use_builtin:
-            self.mish = torch.nn.Mish()
-        else:
-            self.tanh = torch.nn.Tanh()
-            self.softplus = torch.nn.Softplus(
-                beta=softplus_beta, threshold=softplus_threshold
-            )
-
-            self.mish = lambda x: x * self.tanh(self.softplus(x))
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward computation."""
-        return self.mish(x)
-
-
-class Smish(torch.nn.Module):
-    """Smish activation definition.
-
-    Smish(x) = (alpha * x) * tanh(log(1 + sigmoid(beta * x)))
-                 where alpha > 0 and beta > 0
-
-    Reference: https://www.mdpi.com/2079-9292/11/4/540/htm.
-
-    Args:
-        alpha: Alpha value for Smish activation fomulation.
-                 (Usually, alpha = 1. If alpha <= 0, set value to 1).
-        beta: Beta value for Smish activation formulation.
-                (Usually, beta = 1. If beta <= 0, set value to 1).
-
-    """
-
-    def __init__(self, alpha: float = 1.0, beta: float = 1.0) -> None:
-        super().__init__()
-
-        self.tanh = torch.nn.Tanh()
-
-        self.alpha = alpha if alpha > 0 else 1
-        self.beta = beta if beta > 0 else 1
-
-        self.smish = lambda x: (self.alpha * x) * self.tanh(
-            torch.log(1 + torch.sigmoid((self.beta * x)))
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward computation."""
-        return self.smish(x)
-
-
-class Swish(torch.nn.Module):
-    """Swish activation definition.
-
-    Swish(x) = (beta * x) * sigmoid(x)
-                 where beta = 1 defines standard Swish activation.
-
-    References:
-        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
-        E-swish variant: https://arxiv.org/abs/1801.07145.
-
-    Args:
-        beta: Beta parameter for E-Swish.
-                (beta >= 1. If beta < 1, use standard Swish).
-        use_builtin: Whether to use PyTorch function if available.
-
-    """
-
-    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
-        super().__init__()
-
-        self.beta = beta
-
-        if beta > 1:
-            self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
-        else:
-            if use_builtin:
-                self.swish = torch.nn.SiLU()
-            else:
-                self.swish = lambda x: x * torch.sigmoid(x)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward computation."""
-        return self.swish(x)
diff --git a/funasr/modules/beam_search/beam_search_transducer.py b/funasr/modules/beam_search/beam_search_transducer.py
index eaf5627..49cce92 100644
--- a/funasr/modules/beam_search/beam_search_transducer.py
+++ b/funasr/modules/beam_search/beam_search_transducer.py
@@ -6,7 +6,7 @@
 import numpy as np
 import torch
 
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
 from funasr.models.joint_network import JointNetwork
 
 
diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py
index 9b5039c..3746036 100644
--- a/funasr/modules/e2e_asr_common.py
+++ b/funasr/modules/e2e_asr_common.py
@@ -18,7 +18,7 @@
 import torch
 
 from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
 from funasr.models.joint_network import JointNetwork
 
 def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
diff --git a/funasr/modules/normalization.py b/funasr/modules/normalization.py
deleted file mode 100644
index ae35fd4..0000000
--- a/funasr/modules/normalization.py
+++ /dev/null
@@ -1,170 +0,0 @@
-"""Normalization modules for X-former blocks."""
-
-from typing import Dict, Optional, Tuple
-
-import torch
-
-
-def get_normalization(
-    normalization_type: str,
-    eps: Optional[float] = None,
-    partial: Optional[float] = None,
-) -> Tuple[torch.nn.Module, Dict]:
-    """Get normalization module and arguments given parameters.
-
-    Args:
-        normalization_type: Normalization module type.
-        eps: Value added to the denominator.
-        partial: Value defining the part of the input used for RMS stats (RMSNorm).
-
-    Return:
-        : Normalization module class
-        : Normalization module arguments
-
-    """
-    norm = {
-        "basic_norm": (
-            BasicNorm,
-            {"eps": eps if eps is not None else 0.25},
-        ),
-        "layer_norm": (torch.nn.LayerNorm, {"eps": eps if eps is not None else 1e-12}),
-        "rms_norm": (
-            RMSNorm,
-            {
-                "eps": eps if eps is not None else 1e-05,
-                "partial": partial if partial is not None else -1.0,
-            },
-        ),
-        "scale_norm": (
-            ScaleNorm,
-            {"eps": eps if eps is not None else 1e-05},
-        ),
-    }
-
-    return norm[normalization_type]
-
-
-class BasicNorm(torch.nn.Module):
-    """BasicNorm module definition.
-
-    Reference: https://github.com/k2-fsa/icefall/pull/288
-
-    Args:
-        normalized_shape: Expected size.
-        eps: Value added to the denominator for numerical stability.
-
-    """
-
-    def __init__(
-        self,
-        normalized_shape: int,
-        eps: float = 0.25,
-    ) -> None:
-        """Construct a BasicNorm object."""
-        super().__init__()
-
-        self.eps = torch.nn.Parameter(torch.tensor(eps).log().detach())
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Compute basic normalization.
-
-        Args:
-            x: Input sequences. (B, T, D_hidden)
-
-        Returns:
-            : Output sequences. (B, T, D_hidden)
-
-        """
-        scales = (torch.mean(x.pow(2), dim=-1, keepdim=True) + self.eps.exp()) ** -0.5
-
-        return x * scales
-
-
-class RMSNorm(torch.nn.Module):
-    """RMSNorm module definition.
-
-    Reference: https://arxiv.org/pdf/1910.07467.pdf
-
-    Args:
-        normalized_shape: Expected size.
-        eps: Value added to the denominator for numerical stability.
-        partial: Value defining the part of the input used for RMS stats.
-
-    """
-
-    def __init__(
-        self,
-        normalized_shape: int,
-        eps: float = 1e-5,
-        partial: float = 0.0,
-    ) -> None:
-        """Construct a RMSNorm object."""
-        super().__init__()
-
-        self.normalized_shape = normalized_shape
-
-        self.partial = True if 0 < partial < 1 else False
-        self.p = partial
-        self.eps = eps
-
-        self.scale = torch.nn.Parameter(torch.ones(normalized_shape))
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Compute RMS normalization.
-
-        Args:
-            x: Input sequences. (B, T, D_hidden)
-
-        Returns:
-            x: Output sequences. (B, T, D_hidden)
-
-        """
-        if self.partial:
-            partial_size = int(self.normalized_shape * self.p)
-            partial_x, _ = torch.split(
-                x, [partial_size, self.normalized_shape - partial_size], dim=-1
-            )
-
-            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
-            d_x = partial_size
-        else:
-            norm_x = x.norm(2, dim=-1, keepdim=True)
-            d_x = self.normalized_shape
-
-        rms_x = norm_x * d_x ** (-1.0 / 2)
-        x = self.scale * (x / (rms_x + self.eps))
-
-        return x
-
-
-class ScaleNorm(torch.nn.Module):
-    """ScaleNorm module definition.
-
-    Reference: https://arxiv.org/pdf/1910.05895.pdf
-
-    Args:
-        normalized_shape: Expected size.
-        eps: Value added to the denominator for numerical stability.
-
-    """
-
-    def __init__(self, normalized_shape: int, eps: float = 1e-5) -> None:
-        """Construct a ScaleNorm object."""
-        super().__init__()
-
-        self.eps = eps
-        self.scale = torch.nn.Parameter(torch.tensor(normalized_shape**0.5))
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Compute scale normalization.
-
-        Args:
-            x: Input sequences. (B, T, D_hidden)
-
-        Returns:
-            : Output sequences. (B, T, D_hidden)
-
-        """
-        norm = self.scale / torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
-
-        return x * norm
diff --git a/funasr/modules/repeat.py b/funasr/modules/repeat.py
index 7241dd9..2b2dac8 100644
--- a/funasr/modules/repeat.py
+++ b/funasr/modules/repeat.py
@@ -49,13 +49,12 @@
         block_list: List[torch.nn.Module],
         output_size: int,
         norm_class: torch.nn.Module = torch.nn.LayerNorm,
-        norm_args: Optional[Dict] = None,
     ) -> None:
         """Construct a MultiBlocks object."""
         super().__init__()
 
         self.blocks = torch.nn.ModuleList(block_list)
-        self.norm_blocks = norm_class(output_size, **norm_args)
+        self.norm_blocks = norm_class(output_size)
 
         self.num_blocks = len(block_list)
 
diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
index bb1f996..99b3d0c 100644
--- a/funasr/tasks/asr_transducer.py
+++ b/funasr/tasks/asr_transducer.py
@@ -21,9 +21,9 @@
     LightweightConvolutionTransformerDecoder,
     TransformerDecoder,
 )
-from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
-from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
-from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
+from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
+from funasr.models.rnnt_predictor.rnn_decoder import RNNDecoder
+from funasr.models.rnnt_predictor.stateless_decoder import StatelessDecoder
 from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder
 from funasr.models.e2e_transducer import TransducerModel
 from funasr.models.e2e_transducer_unified import UnifiedTransducerModel

--
Gitblit v1.9.1