From 1cdb3cc28d4d89a576cc06e5cd8eb80da1f3a3aa Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 26 四月 2024 11:27:39 +0800
Subject: [PATCH] Dev gzf exp (#1665)

---
 funasr/models/sond/attention.py |   34 ++++++++++++++--------------------
 1 files changed, 14 insertions(+), 20 deletions(-)

diff --git a/funasr/models/sond/attention.py b/funasr/models/sond/attention.py
index 290ab03..1af5534 100644
--- a/funasr/models/sond/attention.py
+++ b/funasr/models/sond/attention.py
@@ -17,6 +17,7 @@
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 import funasr.models.lora.layers as lora
 
+
 class MultiHeadedAttention(nn.Module):
     """Multi-Head Attention layer.
 
@@ -81,9 +82,7 @@
         n_batch = value.size(0)
         if mask is not None:
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
+            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
             scores = scores.masked_fill(mask, min_value)
             self.attn = torch.softmax(scores, dim=-1).masked_fill(
                 mask, 0.0
@@ -116,7 +115,6 @@
         q, k, v = self.forward_qkv(query, key, value)
         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
         return self.forward_attention(v, scores, mask)
-
 
 
 class RelPositionMultiHeadedAttention(MultiHeadedAttention):
@@ -164,7 +162,7 @@
         x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
         x = x_padded[:, :, 1:].view_as(x)[
             :, :, :, : x.size(-1) // 2 + 1
-            ]  # only keep the positions from 0 to time2
+        ]  # only keep the positions from 0 to time2
 
         if self.zero_triu:
             ones = torch.ones((x.size(2), x.size(3)), device=x.device)
@@ -211,15 +209,9 @@
         matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
         matrix_bd = self.rel_shift(matrix_bd)
 
-        scores = (matrix_ac + matrix_bd) / math.sqrt(
-            self.d_k
-        )  # (batch, head, time1, time2)
+        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)  # (batch, head, time1, time2)
 
         return self.forward_attention(v, scores, mask)
-
-
-
-
 
 
 class MultiHeadSelfAttention(nn.Module):
@@ -261,9 +253,15 @@
         b, t, d = x.size()
         q_k_v = self.linear_q_k_v(x)
         q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
-        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
-        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time1, d_k)
+        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
+            1, 2
+        )  # (batch, head, time2, d_k)
 
         return q_h, k_h, v_h, v
 
@@ -287,9 +285,7 @@
 
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
 
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
+            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
             scores = scores.masked_fill(mask, min_value)
             self.attn = torch.softmax(scores, dim=-1).masked_fill(
                 mask, 0.0
@@ -324,5 +320,3 @@
         scores = torch.matmul(q_h, k_h.transpose(-2, -1))
         att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
         return att_outs
-
-

--
Gitblit v1.9.1