From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords

---
 funasr/models/sense_voice/whisper_lib/model.py |   98 ++++++++++++++++++++++++------------------------
 1 files changed, 49 insertions(+), 49 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index ca960f1..3d0d6a8 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -27,9 +27,24 @@
     n_text_layer: int
 
 
+# class LayerNorm(nn.LayerNorm):
+#     def forward(self, x: Tensor) -> Tensor:
+#         return super().forward(x.float()).type(x.dtype)
+
+
 class LayerNorm(nn.LayerNorm):
-    def forward(self, x: Tensor) -> Tensor:
-        return super().forward(x.float()).type(x.dtype)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, input):
+        output = F.layer_norm(
+            input.float(),
+            self.normalized_shape,
+            self.weight.float() if self.weight is not None else None,
+            self.bias.float() if self.bias is not None else None,
+            self.eps,
+        )
+        return output.type_as(input)
 
 
 class Linear(nn.Linear):
@@ -42,9 +57,7 @@
 
 
 class Conv1d(nn.Conv1d):
-    def _conv_forward(
-        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
-    ) -> Tensor:
+    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
         return super()._conv_forward(
             x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
         )
@@ -94,7 +107,12 @@
         return self.out(wv), qk
 
     def qkv_attention(
-        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
+        self,
+        q: Tensor,
+        k: Tensor,
+        v: Tensor,
+        mask: Optional[Tensor] = None,
+        **kwargs,
     ):
         is_pad_mask = kwargs.get("is_pad_mask", False)
         n_batch, n_ctx, n_state = q.shape
@@ -109,11 +127,11 @@
                 qk = qk + mask[:n_ctx, :n_ctx]
             else:
                 mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-                min_value = float(
-                    np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min
-                )
+                min_value = -float(
+                    "inf"
+                )  # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
                 qk = qk.masked_fill(mask, min_value)
-                
+
         qk = qk.float()
 
         w = F.softmax(qk, dim=-1).to(q.dtype)
@@ -129,15 +147,11 @@
         self.attn = MultiHeadAttention(n_state, n_head)
         self.attn_ln = LayerNorm(n_state)
 
-        self.cross_attn = (
-            MultiHeadAttention(n_state, n_head) if cross_attention else None
-        )
+        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
         self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
 
         n_mlp = n_state * 4
-        self.mlp = nn.Sequential(
-            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
-        )
+        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
         self.mlp_ln = LayerNorm(n_state)
 
     def forward(
@@ -152,15 +166,18 @@
         is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
         x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
         if self.cross_attn:
-            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0]
+            x = (
+                x
+                + self.cross_attn(
+                    self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask
+                )[0]
+            )
         x = x + self.mlp(self.mlp_ln(x))
         return x
 
 
 class AudioEncoder(nn.Module):
-    def __init__(
-        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
-    ):
+    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
         super().__init__()
         self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, stride=2, padding=1)
         self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
@@ -184,7 +201,6 @@
         # x = (x + self.positional_embedding).to(x.dtype)
         x = (x + self.positional_embedding[: x.size(1), :]).to(x.dtype)
 
-
         for block in self.blocks:
             x = block(x)
 
@@ -193,19 +209,14 @@
 
 
 class TextDecoder(nn.Module):
-    def __init__(
-        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
-    ):
+    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
         super().__init__()
 
         self.token_embedding = nn.Embedding(n_vocab, n_state)
         self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
 
         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
-            [
-                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
-                for _ in range(n_layer)
-            ]
+            [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
         )
         self.ln = LayerNorm(n_state)
 
@@ -220,19 +231,14 @@
             the encoded audio features to be attended on
         """
         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        x = (
-            self.token_embedding(x)
-            + self.positional_embedding[offset : offset + x.shape[-1]]
-        )
+        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
         x = x.to(xa.dtype)
 
         for block in self.blocks:
             x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
 
         x = self.ln(x)
-        logits = (
-            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
 
         return logits
 
@@ -257,19 +263,15 @@
         )
         # use the last half among the decoder layers for time alignment by default;
         # to use a specific set of heads, see `set_alignment_heads()` below.
-        all_heads = torch.zeros(
-            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
-        )
+        all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
         all_heads[self.dims.n_text_layer // 2 :] = True
-        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+        # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+        # alignment_heads_dense = model.get_buffer("alignment_heads").to_dense()
+        # model.register_buffer("alignment_heads", alignment_heads_dense, persistent=False)
 
     def set_alignment_heads(self, dump: bytes):
-        array = np.frombuffer(
-            gzip.decompress(base64.b85decode(dump)), dtype=bool
-        ).copy()
-        mask = torch.from_numpy(array).reshape(
-            self.dims.n_text_layer, self.dims.n_text_head
-        )
+        array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
+        mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
         self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
 
     def embed_audio(self, mel: torch.Tensor):
@@ -278,9 +280,7 @@
     def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
         return self.decoder(tokens, audio_features)
 
-    def forward(
-        self, mel: torch.Tensor, tokens: torch.Tensor
-    ) -> Dict[str, torch.Tensor]:
+    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
         return self.decoder(tokens, self.encoder(mel))
 
     @property
@@ -330,4 +330,4 @@
 
     detect_language = detect_language_function
     transcribe = transcribe_function
-    decode = decode_function
\ No newline at end of file
+    decode = decode_function

--
Gitblit v1.9.1