zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/data2vec/wav2vec2.py
@@ -23,7 +23,7 @@
            dropout: float = 0.0,
            mode: str = "default",
            conv_bias: bool = False,
            in_d: int = 1
        in_d: int = 1,
    ):
        super().__init__()
@@ -185,9 +185,7 @@
                    ]
                )
            self.pos_conv = make_conv_block(
                self.embedding_dim, k, conv_pos_groups, num_layers
            )
            self.pos_conv = make_conv_block(self.embedding_dim, k, conv_pos_groups, num_layers)
        else:
            self.pos_conv = make_conv_pos(
@@ -206,9 +204,7 @@
        self.layer_norm_first = layer_norm_first
        self.layerdrop = encoder_layerdrop
        self.max_positions = max_positions
        self.layers = nn.ModuleList(
            [self.build_encoder_layer() for _ in range(encoder_layers)]
        )
        self.layers = nn.ModuleList([self.build_encoder_layer() for _ in range(encoder_layers)])
        self.layer_norm = torch.nn.LayerNorm(self.embedding_dim)
        self.apply(utils.init_bert_params)
@@ -240,9 +236,7 @@
            x = self.layer_norm(x)
        # pad to the sequence length dimension
        x, pad_length = utils.pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        x, pad_length = utils.pad_to_multiple(x, self.required_seq_len_multiple, dim=-2, value=0)
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True