liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/sond/sv_decoder.py
@@ -5,19 +5,23 @@
class DenseDecoder(AbsDecoder):
    def __init__(
            self,
            vocab_size,
            encoder_output_size,
            num_nodes_resnet1: int = 256,
            num_nodes_last_layer: int = 256,
            batchnorm_momentum: float = 0.5,
        self,
        vocab_size,
        encoder_output_size,
        num_nodes_resnet1: int = 256,
        num_nodes_last_layer: int = 256,
        batchnorm_momentum: float = 0.5,
    ):
        super(DenseDecoder, self).__init__()
        self.resnet1_dense = torch.nn.Linear(encoder_output_size, num_nodes_resnet1)
        self.resnet1_bn = torch.nn.BatchNorm1d(num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum)
        self.resnet1_bn = torch.nn.BatchNorm1d(
            num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum
        )
        self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer)
        self.resnet2_bn = torch.nn.BatchNorm1d(num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum)
        self.resnet2_bn = torch.nn.BatchNorm1d(
            num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum
        )
        self.output_dense = torch.nn.Linear(num_nodes_last_layer, vocab_size, bias=False)