From 0efc87352ce7d3903dbdedbfa5d01ca5e1cb19e7 Mon Sep 17 00:00:00 2001
From: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Date: Thu, 05 Dec 2024 15:15:38 +0800
Subject: [PATCH] Merge pull request #2267 from modelscope/dev_sx2

---
 funasr/models/branchformer/fastformer.py |   30 ++++++++----------------------
 1 file changed, 8 insertions(+), 22 deletions(-)

diff --git a/funasr/models/branchformer/fastformer.py b/funasr/models/branchformer/fastformer.py
index 24ca947..ede94dc 100644
--- a/funasr/models/branchformer/fastformer.py
+++ b/funasr/models/branchformer/fastformer.py
@@ -81,14 +81,11 @@
 
         # (batch, n_heads, time)
         query_for_score = (
-            self.query_att(mixed_query_layer).transpose(1, 2)
-            / self.attention_head_size**0.5
+            self.query_att(mixed_query_layer).transpose(1, 2) / self.attention_head_size**0.5
         )
         if mask is not None:
             min_value = float(
-                numpy.finfo(
-                    torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
-                ).min
+                numpy.finfo(torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype).min
             )
             query_for_score = query_for_score.masked_fill(mask, min_value)
             query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
@@ -108,9 +105,7 @@
         pooled_query = self.dropout(pooled_query)
         pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)  # (batch, time, size)
 
-        mixed_query_key_layer = (
-            mixed_key_layer * pooled_query_repeat
-        )  # (batch, time, size)
+        mixed_query_key_layer = mixed_key_layer * pooled_query_repeat  # (batch, time, size)
 
         # (batch, n_heads, time)
         query_key_score = (
@@ -118,14 +113,10 @@
         ).transpose(1, 2)
         if mask is not None:
             min_value = float(
-                numpy.finfo(
-                    torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
-                ).min
+                numpy.finfo(torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype).min
             )
             query_key_score = query_key_score.masked_fill(mask, min_value)
-            query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
-                mask, 0.0
-            )
+            query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(mask, 0.0)
         else:
             query_key_weight = torch.softmax(query_key_score, dim=-1)
 
@@ -133,9 +124,7 @@
         key_layer = self.transpose_for_scores(
             mixed_query_key_layer
         )  # (batch, n_heads, time, attn_dim)
-        pooled_key = torch.matmul(
-            query_key_weight, key_layer
-        )  # (batch, n_heads, 1, attn_dim)
+        pooled_key = torch.matmul(query_key_weight, key_layer)  # (batch, n_heads, 1, attn_dim)
         pooled_key = self.dropout(pooled_key)
 
         # NOTE: value = query, due to param sharing
@@ -143,11 +132,8 @@
             1, 2
         )  # (batch, time, n_heads, attn_dim)
         weighted_value = weighted_value.reshape(
-            weighted_value.shape[:-2]
-            + (self.num_attention_heads * self.attention_head_size,)
+            weighted_value.shape[:-2] + (self.num_attention_heads * self.attention_head_size,)
         )  # (batch, time, size)
-        weighted_value = (
-            self.dropout(self.transform(weighted_value)) + mixed_query_layer
-        )
+        weighted_value = self.dropout(self.transform(weighted_value)) + mixed_query_layer
 
         return weighted_value

--
Gitblit v1.9.1