From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords
---
funasr/models/sense_voice/whisper_lib/model.py | 98 ++++++++++++++++++++++++------------------------
1 files changed, 49 insertions(+), 49 deletions(-)
diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index ca960f1..3d0d6a8 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -27,9 +27,24 @@
n_text_layer: int
+# class LayerNorm(nn.LayerNorm):
+# def forward(self, x: Tensor) -> Tensor:
+# return super().forward(x.float()).type(x.dtype)
+
+
class LayerNorm(nn.LayerNorm):
- def forward(self, x: Tensor) -> Tensor:
- return super().forward(x.float()).type(x.dtype)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input):
+ output = F.layer_norm(
+ input.float(),
+ self.normalized_shape,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ )
+ return output.type_as(input)
class Linear(nn.Linear):
@@ -42,9 +57,7 @@
class Conv1d(nn.Conv1d):
- def _conv_forward(
- self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
- ) -> Tensor:
+ def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
@@ -94,7 +107,12 @@
return self.out(wv), qk
def qkv_attention(
- self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
+ self,
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ mask: Optional[Tensor] = None,
+ **kwargs,
):
is_pad_mask = kwargs.get("is_pad_mask", False)
n_batch, n_ctx, n_state = q.shape
@@ -109,11 +127,11 @@
qk = qk + mask[:n_ctx, :n_ctx]
else:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
- min_value = float(
- np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min
- )
+ min_value = -float(
+ "inf"
+ ) # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
qk = qk.masked_fill(mask, min_value)
-
+
qk = qk.float()
w = F.softmax(qk, dim=-1).to(q.dtype)
@@ -129,15 +147,11 @@
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
- self.cross_attn = (
- MultiHeadAttention(n_state, n_head) if cross_attention else None
- )
+ self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
- self.mlp = nn.Sequential(
- Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
- )
+ self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
self.mlp_ln = LayerNorm(n_state)
def forward(
@@ -152,15 +166,18 @@
is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
if self.cross_attn:
- x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0]
+ x = (
+ x
+ + self.cross_attn(
+ self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask
+ )[0]
+ )
x = x + self.mlp(self.mlp_ln(x))
return x
class AudioEncoder(nn.Module):
- def __init__(
- self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
- ):
+ def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, stride=2, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
@@ -184,7 +201,6 @@
# x = (x + self.positional_embedding).to(x.dtype)
x = (x + self.positional_embedding[: x.size(1), :]).to(x.dtype)
-
for block in self.blocks:
x = block(x)
@@ -193,19 +209,14 @@
class TextDecoder(nn.Module):
- def __init__(
- self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
- ):
+ def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
- [
- ResidualAttentionBlock(n_state, n_head, cross_attention=True)
- for _ in range(n_layer)
- ]
+ [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
)
self.ln = LayerNorm(n_state)
@@ -220,19 +231,14 @@
the encoded audio features to be attended on
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
- x = (
- self.token_embedding(x)
- + self.positional_embedding[offset : offset + x.shape[-1]]
- )
+ x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
x = x.to(xa.dtype)
for block in self.blocks:
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
x = self.ln(x)
- logits = (
- x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
- ).float()
+ logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
return logits
@@ -257,19 +263,15 @@
)
# use the last half among the decoder layers for time alignment by default;
# to use a specific set of heads, see `set_alignment_heads()` below.
- all_heads = torch.zeros(
- self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
- )
+ all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
all_heads[self.dims.n_text_layer // 2 :] = True
- self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+ # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+ # alignment_heads_dense = model.get_buffer("alignment_heads").to_dense()
+ # model.register_buffer("alignment_heads", alignment_heads_dense, persistent=False)
def set_alignment_heads(self, dump: bytes):
- array = np.frombuffer(
- gzip.decompress(base64.b85decode(dump)), dtype=bool
- ).copy()
- mask = torch.from_numpy(array).reshape(
- self.dims.n_text_layer, self.dims.n_text_head
- )
+ array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
+ mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
def embed_audio(self, mel: torch.Tensor):
@@ -278,9 +280,7 @@
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder(tokens, audio_features)
- def forward(
- self, mel: torch.Tensor, tokens: torch.Tensor
- ) -> Dict[str, torch.Tensor]:
+ def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
@property
@@ -330,4 +330,4 @@
detect_language = detect_language_function
transcribe = transcribe_function
- decode = decode_function
\ No newline at end of file
+ decode = decode_function
--
Gitblit v1.9.1