游雁
2024-06-24 1596f6f414f6f41da66506debb1dff19fffeb3ec
funasr/models/sense_voice/whisper_lib/model.py
@@ -27,9 +27,24 @@
    n_text_layer: int
# class LayerNorm(nn.LayerNorm):
#     def forward(self, x: Tensor) -> Tensor:
#         return super().forward(x.float()).type(x.dtype)
class LayerNorm(nn.LayerNorm):
    """Layer normalization that always computes in float32.

    The input and the affine parameters (when present) are cast to float32
    for the normalization itself, and the result is cast back to the type of
    the input tensor. The constructor is inherited unchanged from
    ``nn.LayerNorm``.
    """

    def forward(self, input: Tensor) -> Tensor:
        """Normalize ``input`` in float32 and return it in its original type.

        Args:
            input: Tensor to normalize over ``self.normalized_shape``.

        Returns:
            The normalized tensor, cast back with ``type_as(input)``.
        """
        # weight/bias are None when the module has no affine parameters, so
        # guard before casting them to float32.
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            weight,
            bias,
            self.eps,
        )
        return output.type_as(input)
class Linear(nn.Linear):
@@ -112,7 +127,9 @@
                qk = qk + mask[:n_ctx, :n_ctx]
            else:
                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
                min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
                min_value = -float(
                    "inf"
                )  # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
                qk = qk.masked_fill(mask, min_value)
        qk = qk.float()