From 580b11b57ac4b62f7e2acda73813a4e10e8e4cd3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 10 十月 2023 17:17:29 +0800
Subject: [PATCH] v0.8.0
---
funasr/modules/attention.py | 83 ++++++++++++++++++++++++++++++++++++++++-
1 files changed, 80 insertions(+), 3 deletions(-)
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index f01e340..b007d58 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -338,7 +338,10 @@
else:
self.linear_out = nn.Linear(n_feat, n_feat)
lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
- self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+ if lora_qkv_list == [False, False, False]:
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ else:
+ self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
else:
self.linear_out = nn.Linear(n_feat, n_feat)
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
@@ -453,6 +456,44 @@
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
return att_outs + fsmn_memory
+ def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q_h, k_h, v_h, v = self.forward_qkv(x)
+ if chunk_size is not None and look_back > 0 or look_back == -1:
+ if cache is not None:
+ k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
+ v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
+ k_h = torch.cat((cache["k"], k_h), dim=2)
+ v_h = torch.cat((cache["v"], v_h), dim=2)
+
+ cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
+ cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
+ if look_back != -1:
+ cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
+ cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
+ else:
+ cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :],
+ "v": v_h[:, :, :-(chunk_size[2]), :]}
+ cache = cache_tmp
+ fsmn_memory = self.forward_fsmn(v, None)
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ att_outs = self.forward_attention(v_h, scores, None)
+ return att_outs + fsmn_memory, cache
+
+
class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -562,11 +603,18 @@
if lora_list is not None:
if "q" in lora_list:
self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_q = nn.Linear(n_feat, n_feat)
lora_kv_list = ["k" in lora_list, "v" in lora_list]
- self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
- r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+ if lora_kv_list == [False, False]:
+ self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+ else:
+ self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
+ r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
if "o" in lora_list:
self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
else:
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
@@ -657,6 +705,35 @@
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
return self.forward_attention(v_h, scores, memory_mask)
+ def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q_h, k_h, v_h = self.forward_qkv(x, memory)
+ if chunk_size is not None and look_back > 0:
+ if cache is not None:
+ k_h = torch.cat((cache["k"], k_h), dim=2)
+ v_h = torch.cat((cache["v"], v_h), dim=2)
+ cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
+ cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
+ else:
+ cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :],
+ "v": v_h[:, :, -(look_back * chunk_size[1]):, :]}
+ cache = cache_tmp
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ return self.forward_attention(v_h, scores, None), cache
+
class MultiHeadSelfAttention(nn.Module):
"""Multi-Head Attention layer.
--
Gitblit v1.9.1