From 8b0fb74bded1f8a162e6c0e94c3522be6216ea03 Mon Sep 17 00:00:00 2001
From: chengligen <101448376+chengligen@users.noreply.github.com>
Date: 星期一, 26 五月 2025 14:11:33 +0800
Subject: [PATCH] feat: add 'words' key aligned with timestamps in sensevoice model output (#2531)

---
 funasr/models/sond/attention.py |   46 ++++++++++++++++++++--------------------------
 1 files changed, 20 insertions(+), 26 deletions(-)

diff --git a/funasr/models/sond/attention.py b/funasr/models/sond/attention.py
index 290ab03..18580b7 100644
--- a/funasr/models/sond/attention.py
+++ b/funasr/models/sond/attention.py
@@ -17,6 +17,7 @@
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 import funasr.models.lora.layers as lora
 
+
 class MultiHeadedAttention(nn.Module):
     """Multi-Head Attention layer.
 
@@ -81,17 +82,15 @@
         n_batch = value.size(0)
         if mask is not None:
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
+            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
             scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+            attn = torch.softmax(scores, dim=-1).masked_fill(
                 mask, 0.0
             )  # (batch, head, time1, time2)
         else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
 
-        p_attn = self.dropout(self.attn)
+        p_attn = self.dropout(attn)
         x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
         x = (
             x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@@ -116,7 +115,6 @@
         q, k, v = self.forward_qkv(query, key, value)
         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
         return self.forward_attention(v, scores, mask)
-
 
 
 class RelPositionMultiHeadedAttention(MultiHeadedAttention):
@@ -164,7 +162,7 @@
         x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
         x = x_padded[:, :, 1:].view_as(x)[
             :, :, :, : x.size(-1) // 2 + 1
-            ]  # only keep the positions from 0 to time2
+        ]  # only keep the positions from 0 to time2
 
         if self.zero_triu:
             ones = torch.ones((x.size(2), x.size(3)), device=x.device)
@@ -211,15 +209,9 @@
         matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
         matrix_bd = self.rel_shift(matrix_bd)
 
-        scores = (matrix_ac + matrix_bd) / math.sqrt(
-            self.d_k
-        )  # (batch, head, time1, time2)
+        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)  # (batch, head, time1, time2)
 
         return self.forward_attention(v, scores, mask)
-
-
-
-
 
 
 class MultiHeadSelfAttention(nn.Module):
@@ -261,9 +253,15 @@
         b, t, d = x.size()
         q_k_v = self.linear_q_k_v(x)
         q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
-        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
-        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time2, d_k)
 
         return q_h, k_h, v_h, v
 
@@ -287,17 +285,15 @@
 
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
 
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
+            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
             scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+            attn = torch.softmax(scores, dim=-1).masked_fill(
                 mask, 0.0
             )  # (batch, head, time1, time2)
         else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
 
-        p_attn = self.dropout(self.attn)
+        p_attn = self.dropout(attn)
         x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
         x = (
             x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
@@ -324,5 +320,3 @@
         scores = torch.matmul(q_h, k_h.transpose(-2, -1))
         att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
         return att_outs
-
-

--
Gitblit v1.9.1