From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] Reformat whisper_lid decoder (line re-wrapping only, no functional change)

---
 funasr/models/whisper_lid/decoder.py |   31 +++++++++++++------------------
 1 files changed, 13 insertions(+), 18 deletions(-)

diff --git a/funasr/models/whisper_lid/decoder.py b/funasr/models/whisper_lid/decoder.py
index 4db9205..1c8ab47 100644
--- a/funasr/models/whisper_lid/decoder.py
+++ b/funasr/models/whisper_lid/decoder.py
@@ -29,9 +29,7 @@
         super().__init__()
 
         assert whisper_model in whisper.available_models()
-        _model = whisper.load_model(
-            whisper_model, download_root=download_dir, device="cpu"
-        )
+        _model = whisper.load_model(whisper_model, download_root=download_dir, device="cpu")
         self.decoders = copy.deepcopy(_model.decoder)
         attention_dim = self.decoders.token_embedding.embedding_dim
 
@@ -67,10 +65,7 @@
             olens: (batch, )
         """
         tgt, memory = ys_in_pad, hs_pad
-        tgt = (
-            self.decoders.token_embedding(tgt)
-            + self.decoders.positional_embedding[: tgt.size(1)]
-        )
+        tgt = self.decoders.token_embedding(tgt) + self.decoders.positional_embedding[: tgt.size(1)]
         tgt = self.dropout(tgt)
 
         x = tgt.to(memory.dtype)
@@ -81,15 +76,20 @@
             memory_mask = None
 
         for layer, block in enumerate(self.decoders.blocks):
-            x = block(x, memory, mask=self.decoders.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
+            x = block(
+                x,
+                memory,
+                mask=self.decoders.mask,
+                memory_mask=memory_mask,
+                is_pad_mask=False,
+                is_pad_memory_mask=True,
+            )
 
             if layer < len(self.decoders.blocks) - 1:
                 x = self.dropout(x)
 
         x = self.decoders.ln(x)
-        x = (
-            x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        x = (x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)).float()
 
         return x, ys_in_lens
 
@@ -116,10 +116,7 @@
             cache implementation is ignored for now
             for simplicity & correctness
         """
-        x = (
-            self.decoders.token_embedding(tgt)
-            + self.decoders.positional_embedding[: tgt.size(1)]
-        )
+        x = self.decoders.token_embedding(tgt) + self.decoders.positional_embedding[: tgt.size(1)]
         x = self.dropout(x)
         x = x.to(memory.dtype)
 
@@ -130,9 +127,7 @@
 
         x = self.decoders.ln(x)
         y = x[:, -1]
-        y = (
-            y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        y = (y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)).float()
         y = torch.log_softmax(y, dim=-1)
 
         return y, None

--
Gitblit v1.9.1