From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] Reformat expressions in funasr/models/whisper_lid/decoder.py (no functional change)
---
funasr/models/whisper_lid/decoder.py | 31 +++++++++++++------------------
 1 file changed, 13 insertions(+), 18 deletions(-)
diff --git a/funasr/models/whisper_lid/decoder.py b/funasr/models/whisper_lid/decoder.py
index 4db9205..1c8ab47 100644
--- a/funasr/models/whisper_lid/decoder.py
+++ b/funasr/models/whisper_lid/decoder.py
@@ -29,9 +29,7 @@
super().__init__()
assert whisper_model in whisper.available_models()
- _model = whisper.load_model(
- whisper_model, download_root=download_dir, device="cpu"
- )
+ _model = whisper.load_model(whisper_model, download_root=download_dir, device="cpu")
self.decoders = copy.deepcopy(_model.decoder)
attention_dim = self.decoders.token_embedding.embedding_dim
@@ -67,10 +65,7 @@
olens: (batch, )
"""
tgt, memory = ys_in_pad, hs_pad
- tgt = (
- self.decoders.token_embedding(tgt)
- + self.decoders.positional_embedding[: tgt.size(1)]
- )
+ tgt = self.decoders.token_embedding(tgt) + self.decoders.positional_embedding[: tgt.size(1)]
tgt = self.dropout(tgt)
x = tgt.to(memory.dtype)
@@ -81,15 +76,20 @@
memory_mask = None
for layer, block in enumerate(self.decoders.blocks):
- x = block(x, memory, mask=self.decoders.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
+ x = block(
+ x,
+ memory,
+ mask=self.decoders.mask,
+ memory_mask=memory_mask,
+ is_pad_mask=False,
+ is_pad_memory_mask=True,
+ )
if layer < len(self.decoders.blocks) - 1:
x = self.dropout(x)
x = self.decoders.ln(x)
- x = (
- x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
- ).float()
+ x = (x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)).float()
return x, ys_in_lens
@@ -116,10 +116,7 @@
cache implementation is ignored for now
for simplicity & correctness
"""
- x = (
- self.decoders.token_embedding(tgt)
- + self.decoders.positional_embedding[: tgt.size(1)]
- )
+ x = self.decoders.token_embedding(tgt) + self.decoders.positional_embedding[: tgt.size(1)]
x = self.dropout(x)
x = x.to(memory.dtype)
@@ -130,9 +127,7 @@
x = self.decoders.ln(x)
y = x[:, -1]
- y = (
- y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
- ).float()
+ y = (y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)).float()
y = torch.log_softmax(y, dim=-1)
return y, None
--
Gitblit v1.9.1
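
For context, the hunks above only re-wrap existing expressions; the computation they touch is the tied-embedding output projection followed by log_softmax. Below is a minimal, self-contained sketch of that step, assuming PyTorch; vocab_size, dim, and the tensor shapes are made-up illustrative values, and token_embedding stands in for self.decoders.token_embedding. This is not part of the patch.

    import torch
    import torch.nn as nn

    # Illustrative sizes; not taken from the patch.
    vocab_size, dim = 1000, 64
    token_embedding = nn.Embedding(vocab_size, dim)

    # x: decoder hidden states after the final LayerNorm, shape (batch, time, dim).
    x = torch.randn(2, 5, dim)

    # Project back onto the vocabulary through the transposed embedding matrix,
    # mirroring `x @ torch.transpose(token_embedding.weight.to(x.dtype), 0, 1)`.
    logits = (x @ token_embedding.weight.to(x.dtype).T).float()  # (batch, time, vocab_size)

    # Single-step decoding keeps only the last position before normalizing.
    y = torch.log_softmax(logits[:, -1], dim=-1)  # (batch, vocab_size)
    print(y.shape)  # torch.Size([2, 1000])

Reusing the embedding matrix as the output projection (weight tying) keeps the decoder's input and output vocabularies aligned without a separate output linear layer.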