| | |
| | | n_text_layer: int |
| | | |
| | | |
| | | # class LayerNorm(nn.LayerNorm): |
| | | # def forward(self, x: Tensor) -> Tensor: |
| | | # return super().forward(x.float()).type(x.dtype) |
| | | |
| | | |
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` variant that always normalizes in float32.

    The input and, when present, the affine ``weight``/``bias`` are upcast to
    float32 before calling ``F.layer_norm``. The result is then cast back with
    ``type_as(input)``. The normalization itself is therefore never computed
    in the input's or parameters' reduced precision (e.g. float16), while
    callers still receive a tensor of the input's type.

    Construction is inherited unchanged from ``nn.LayerNorm``.
    """

    def forward(self, input: Tensor) -> Tensor:
        """Apply layer normalization in float32 and return it in ``input``'s type.

        Args:
            input: Tensor whose trailing dimensions match ``normalized_shape``.

        Returns:
            The normalized tensor, cast to the same type as ``input``.
        """
        # The parameter is named ``input`` (shadowing the builtin) to match
        # ``nn.LayerNorm.forward`` so keyword callers keep working.
        # weight/bias are None when the module was built without an affine
        # transform; only upcast them when they exist.
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            weight,
            bias,
            self.eps,
        )
        return output.type_as(input)
| | | |
| | | |
| | | class Linear(nn.Linear): |