From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords

---
 funasr/models/sense_voice/whisper_lib/model.py |   23 ++++++++++++++++++++---
 1 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index 40939df..3d0d6a8 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -27,9 +27,24 @@
     n_text_layer: int
 
 
+# class LayerNorm(nn.LayerNorm):
+#     def forward(self, x: Tensor) -> Tensor:
+#         return super().forward(x.float()).type(x.dtype)
+
+
 class LayerNorm(nn.LayerNorm):
-    def forward(self, x: Tensor) -> Tensor:
-        return super().forward(x.float()).type(x.dtype)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, input):
+        output = F.layer_norm(
+            input.float(),
+            self.normalized_shape,
+            self.weight.float() if self.weight is not None else None,
+            self.bias.float() if self.bias is not None else None,
+            self.eps,
+        )
+        return output.type_as(input)
 
 
 class Linear(nn.Linear):
@@ -112,7 +127,9 @@
                 qk = qk + mask[:n_ctx, :n_ctx]
             else:
                 mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-                min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
+                min_value = -float(
+                    "inf"
+                )  # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
                 qk = qk.masked_fill(mask, min_value)
 
         qk = qk.float()

--
Gitblit v1.9.1