From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/modules/attention.py | 81 ++++++++++++++++++++++++++++++++++++----
1 files changed, 72 insertions(+), 9 deletions(-)
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index 6202079..ab59493 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -13,6 +13,10 @@
from torch import nn
from typing import Optional, Tuple
+import torch.nn.functional as F
+from funasr.modules.nets_utils import make_pad_mask
+import funasr.modules.lora.layers as lora
+
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
@@ -318,7 +322,7 @@
"""
- def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
+ def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttentionSANM, self).__init__()
assert n_feat % n_head == 0
@@ -328,8 +332,19 @@
# self.linear_q = nn.Linear(n_feat, n_feat)
# self.linear_k = nn.Linear(n_feat, n_feat)
# self.linear_v = nn.Linear(n_feat, n_feat)
- self.linear_out = nn.Linear(n_feat, n_feat)
- self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ if lora_list is not None:
+ if "o" in lora_list:
+ self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
+ if lora_qkv_list == [False, False, False]:
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ else:
+ self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
@@ -540,18 +555,32 @@
"""
- def __init__(self, n_head, n_feat, dropout_rate, encoder_output_size=None):
+ def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttentionCrossAtt, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
- self.linear_q = nn.Linear(n_feat, n_feat)
- # self.linear_k = nn.Linear(n_feat, n_feat)
- # self.linear_v = nn.Linear(n_feat, n_feat)
- self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
- self.linear_out = nn.Linear(n_feat, n_feat)
+ if lora_list is not None:
+ if "q" in lora_list:
+ self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ lora_kv_list = ["k" in lora_list, "v" in lora_list]
+ if lora_kv_list == [False, False]:
+ self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+ else:
+ self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
+ r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+ if "o" in lora_list:
+ self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ else:
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+ self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
@@ -959,3 +988,37 @@
q, k, v = self.forward_qkv(query, key, value)
scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
+
+
+class CosineDistanceAttention(nn.Module):
+ """ Compute Cosine Distance between spk decoder output and speaker profile
+ Args:
+ profile_path: speaker profile file path (.npy file)
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, spk_decoder_out, profile, profile_lens=None):
+ """
+ Args:
+ spk_decoder_out(torch.Tensor):(B, L, D)
+ spk_profiles(torch.Tensor):(B, N, D)
+ """
+ x = spk_decoder_out.unsqueeze(2) # (B, L, 1, D)
+ if profile_lens is not None:
+
+ mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
+ )
+ weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
+ weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0) # (B, L, N)
+ else:
+ x = x[:, -1:, :, :]
+ weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
+ weights = self.softmax(weights_not_softmax) # (B, 1, N)
+ spk_embedding = torch.matmul(weights, profile.to(weights.device)) # (B, L, D)
+
+ return spk_embedding, weights
--
Gitblit v1.9.1