aky15
2023-04-14 fa25b637b0d257186a8399eb1c530a91f4252702
funasr/models/joint_network.py
@@ -2,7 +2,7 @@
import torch
from funasr.modules.activation import get_activation
from funasr.modules.nets_utils import get_activation
class JointNetwork(torch.nn.Module):
@@ -25,7 +25,6 @@
        decoder_size: int,
        joint_space_size: int = 256,
        joint_activation_type: str = "tanh",
        **activation_parameters,
    ) -> None:
        """Construct a JointNetwork object."""
        super().__init__()
@@ -36,7 +35,7 @@
        self.lin_out = torch.nn.Linear(joint_space_size, output_size)
        self.joint_activation = get_activation(
            joint_activation_type, **activation_parameters
            joint_activation_type
        )
    def forward(