| | |
| | | n_text_layer: int |
| | | |
| | | |
| | | # class LayerNorm(nn.LayerNorm): |
| | | # def forward(self, x: Tensor) -> Tensor: |
| | | # return super().forward(x.float()).type(x.dtype) |
| | | |
| | | |
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that always normalizes in float32.

    The input and the affine parameters are upcast to float32 before
    normalization, so the computation runs in full precision even when the
    module or its input uses a reduced-precision dtype (e.g. float16).
    The result is cast back to the input's dtype.
    """

    def forward(self, input: Tensor) -> Tensor:
        """Apply layer normalization to ``input`` in float32.

        Args:
            input: Tensor whose trailing dimensions match
                ``self.normalized_shape``.

        Returns:
            The normalized tensor, cast back to ``input``'s dtype.
        """
        # Upcast the affine parameters too, so F.layer_norm sees matching
        # dtypes for input, weight and bias even if the module was .half()'d.
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            weight,
            bias,
            self.eps,
        )
        return output.type_as(input)
| | | |
| | | |
| | | class Linear(nn.Linear): |
| | |
| | | qk = qk + mask[:n_ctx, :n_ctx] |
| | | else: |
| | | mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) |
| | | min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min) |
| | | min_value = -float( |
| | | "inf" |
| | | ) # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min) |
| | | qk = qk.masked_fill(mask, min_value) |
| | | |
| | | qk = qk.float() |