From eaf9dda9e4d970af3d09db695e9e10c83ef94e25 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 17 四月 2024 15:05:37 +0800
Subject: [PATCH] Dev gzf exp (#1624)
---
funasr/models/sense_voice/whisper_lib/model.py | 27 ++++++++++++++++++++++-----
1 files changed, 22 insertions(+), 5 deletions(-)
diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index 0e8f09b..ca960f1 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -74,7 +74,10 @@
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
+ **kwargs,
):
+ is_pad_mask = kwargs.get("is_pad_mask", False)
+
q = self.query(x)
if kv_cache is None or xa is None or self.key not in kv_cache:
@@ -87,12 +90,13 @@
k = kv_cache[self.key]
v = kv_cache[self.value]
- wv, qk = self.qkv_attention(q, k, v, mask)
+ wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
return self.out(wv), qk
def qkv_attention(
- self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
):
+ is_pad_mask = kwargs.get("is_pad_mask", False)
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
@@ -101,10 +105,20 @@
qk = q @ k
if mask is not None:
- qk = qk + mask[:n_ctx, :n_ctx]
+ if not is_pad_mask:
+ qk = qk + mask[:n_ctx, :n_ctx]
+ else:
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ min_value = float(
+ np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min
+ )
+ qk = qk.masked_fill(mask, min_value)
+
qk = qk.float()
w = F.softmax(qk, dim=-1).to(q.dtype)
+ if mask is not None and is_pad_mask:
+ w = w.masked_fill(mask, 0.0)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
@@ -132,10 +146,13 @@
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
+ **kwargs,
):
- x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
+ is_pad_mask = kwargs.get("is_pad_mask", False)
+ is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
+ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
if self.cross_attn:
- x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
+ x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0]
x = x + self.mlp(self.mlp_ln(x))
return x
--
Gitblit v1.9.1