yhliang
2023-05-11 1d1ef01b4e23630a99a3be7e9d1dce9550a793e9
funasr/modules/nets_utils.py
@@ -485,14 +485,39 @@
        new_k = k.replace(old_prefix, new_prefix)
        state_dict[new_k] = v
class Swish(torch.nn.Module):
    """Construct an Swish object."""
    """Swish activation definition.
    def forward(self, x):
        """Return Swich activation function."""
        return x * torch.sigmoid(x)
    Swish(x) = (beta * x) * sigmoid(x)
                 where beta = 1 defines standard Swish activation.
    References:
        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
        E-swish variant: https://arxiv.org/abs/1801.07145.
    Args:
        beta: Beta parameter for E-Swish.
                (beta >= 1. If beta < 1, use standard Swish).
        use_builtin: Whether to use PyTorch function if available.
    """
    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
        super().__init__()
        self.beta = beta
        if beta > 1:
            self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
        else:
            if use_builtin:
                self.swish = torch.nn.SiLU()
            else:
                self.swish = lambda x: x * torch.sigmoid(x)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward computation."""
        return self.swish(x)
def get_activation(act):
    """Return activation function."""