From 0efc87352ce7d3903dbdedbfa5d01ca5e1cb19e7 Mon Sep 17 00:00:00 2001
From: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Date: Thu, 05 Dec 2024 15:15:38 +0800
Subject: [PATCH] Merge pull request #2267 from modelscope/dev_sx2
---
funasr/models/branchformer/fastformer.py | 30 ++++++++----------------------
 1 file changed, 8 insertions(+), 22 deletions(-)
diff --git a/funasr/models/branchformer/fastformer.py b/funasr/models/branchformer/fastformer.py
index 24ca947..ede94dc 100644
--- a/funasr/models/branchformer/fastformer.py
+++ b/funasr/models/branchformer/fastformer.py
@@ -81,14 +81,11 @@
# (batch, n_heads, time)
query_for_score = (
- self.query_att(mixed_query_layer).transpose(1, 2)
- / self.attention_head_size**0.5
+ self.query_att(mixed_query_layer).transpose(1, 2) / self.attention_head_size**0.5
)
if mask is not None:
min_value = float(
- numpy.finfo(
- torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
- ).min
+ numpy.finfo(torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype).min
)
query_for_score = query_for_score.masked_fill(mask, min_value)
query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
@@ -108,9 +105,7 @@
pooled_query = self.dropout(pooled_query)
pooled_query_repeat = pooled_query.repeat(1, seq_len, 1) # (batch, time, size)
- mixed_query_key_layer = (
- mixed_key_layer * pooled_query_repeat
- ) # (batch, time, size)
+ mixed_query_key_layer = mixed_key_layer * pooled_query_repeat # (batch, time, size)
# (batch, n_heads, time)
query_key_score = (
@@ -118,14 +113,10 @@
).transpose(1, 2)
if mask is not None:
min_value = float(
- numpy.finfo(
- torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
- ).min
+ numpy.finfo(torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype).min
)
query_key_score = query_key_score.masked_fill(mask, min_value)
- query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
- mask, 0.0
- )
+ query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(mask, 0.0)
else:
query_key_weight = torch.softmax(query_key_score, dim=-1)
@@ -133,9 +124,7 @@
key_layer = self.transpose_for_scores(
mixed_query_key_layer
) # (batch, n_heads, time, attn_dim)
- pooled_key = torch.matmul(
- query_key_weight, key_layer
- ) # (batch, n_heads, 1, attn_dim)
+ pooled_key = torch.matmul(query_key_weight, key_layer) # (batch, n_heads, 1, attn_dim)
pooled_key = self.dropout(pooled_key)
# NOTE: value = query, due to param sharing
@@ -143,11 +132,8 @@
1, 2
) # (batch, time, n_heads, attn_dim)
weighted_value = weighted_value.reshape(
- weighted_value.shape[:-2]
- + (self.num_attention_heads * self.attention_head_size,)
+ weighted_value.shape[:-2] + (self.num_attention_heads * self.attention_head_size,)
) # (batch, time, size)
- weighted_value = (
- self.dropout(self.transform(weighted_value)) + mixed_query_layer
- )
+ weighted_value = self.dropout(self.transform(weighted_value)) + mixed_query_layer
return weighted_value
--
Gitblit v1.9.1