From e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc Mon Sep 17 00:00:00 2001
From: VirtuosoQ <2416050435@qq.com>
Date: Fri, 26 Apr 2024 14:59:30 +0800
Subject: [PATCH] whisper_lib: support padding masks in MultiHeadAttention

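Thread an `is_pad_mask` flag through MultiHeadAttention so the Whisper layers
used by SenseVoice can take a per-utterance padding mask in addition to the
existing causal mask: padded key positions are filled with the smallest finite
value of the score dtype before softmax, and their attention weights are
zeroed afterwards. ResidualAttentionBlock forwards `is_pad_mask` /
`is_pad_memory_mask` to self- and cross-attention, and the sparse
`alignment_heads` buffer registration is commented out.

A minimal usage sketch (illustration only, not part of this patch; the
variable names and the ResidualAttentionBlock call are assumptions):

    import torch

    # Hypothetical batch: two utterances with 100 and 60 valid frames.
    lengths = torch.tensor([100, 60])
    t_max = int(lengths.max())
    # (batch, 1, time) mask: nonzero marks a real frame, zero marks padding,
    # matching the mask.unsqueeze(1).eq(0) convention in qkv_attention.
    pad_mask = (torch.arange(t_max)[None, :] < lengths[:, None]).unsqueeze(1)
    # `block` stands for a ResidualAttentionBlock; `x` is (batch, time, n_state).
    x = block(x, mask=pad_mask, is_pad_mask=True)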
---
 funasr/models/sense_voice/whisper_lib/model.py |   31 +++++++++++++++++++++++++------
 1 file changed, 25 insertions(+), 6 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index 0e8f09b..5f7caeb 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -74,7 +74,10 @@
         xa: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         kv_cache: Optional[dict] = None,
+        **kwargs,
     ):
+        is_pad_mask = kwargs.get("is_pad_mask", False)
+
         q = self.query(x)
 
         if kv_cache is None or xa is None or self.key not in kv_cache:
@@ -87,12 +90,13 @@
             k = kv_cache[self.key]
             v = kv_cache[self.value]
 
-        wv, qk = self.qkv_attention(q, k, v, mask)
+        wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
         return self.out(wv), qk
 
     def qkv_attention(
-        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
+        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
     ):
+        is_pad_mask = kwargs.get("is_pad_mask", False)
         n_batch, n_ctx, n_state = q.shape
         scale = (n_state // self.n_head) ** -0.25
         q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
@@ -101,10 +105,20 @@
 
         qk = q @ k
         if mask is not None:
-            qk = qk + mask[:n_ctx, :n_ctx]
+            if not is_pad_mask:
+                qk = qk + mask[:n_ctx, :n_ctx]
+            else:
+                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+                min_value = float(
+                    np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min
+                )
+                qk = qk.masked_fill(mask, min_value)
+
         qk = qk.float()
 
         w = F.softmax(qk, dim=-1).to(q.dtype)
+        if mask is not None and is_pad_mask:
+            w = w.masked_fill(mask, 0.0)
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
 
 
@@ -132,10 +146,13 @@
         xa: Optional[Tensor] = None,
         mask: Optional[Tensor] = None,
         kv_cache: Optional[dict] = None,
+        **kwargs,
     ):
-        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
+        is_pad_mask = kwargs.get("is_pad_mask", False)
+        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
+        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
         if self.cross_attn:
-            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0]
         x = x + self.mlp(self.mlp_ln(x))
         return x
 
@@ -244,7 +261,9 @@
             self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
         )
         all_heads[self.dims.n_text_layer // 2 :] = True
-        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+        # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+        # alignment_heads_dense = model.get_buffer("alignment_heads").to_dense()
+        # model.register_buffer("alignment_heads", alignment_heads_dense, persistent=False)
 
     def set_alignment_heads(self, dump: bytes):
         array = np.frombuffer(

--
Gitblit v1.9.1