From ba3a3bf4e67e861b833092d05d7c3842ea670cbc Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 28 五月 2024 14:53:18 +0800
Subject: [PATCH] Add files via upload
---
funasr/models/llm_asr/adaptor.py | 38 +++++++++++++++++++++++++++++++++++++-
1 files changed, 37 insertions(+), 1 deletions(-)
diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py
index 0676e7d..8c2a804 100644
--- a/funasr/models/llm_asr/adaptor.py
+++ b/funasr/models/llm_asr/adaptor.py
@@ -3,6 +3,7 @@
from funasr.register import tables
+
@tables.register("adaptor_classes", "Linear")
class Linear(nn.Module):
def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
@@ -20,10 +21,45 @@
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)
-
+
x = x.contiguous()
x = x.view(batch_size, seq_len // self.k, dim * self.k)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
+
+
+@tables.register("adaptor_classes", "QFormer")
+class EncoderProjectorQFormer(nn.Module):
+ def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
+ super().__init__()
+ self.encoder_dim = encoder_dim
+ self.llm_dim = llm_dim
+ from transformers import Blip2QFormerConfig, Blip2QFormerModel
+
+ configuration = Blip2QFormerConfig()
+ configuration.encoder_hidden_size = self.encoder_dim
+ configuration.num_hidden_layers = 2
+
+ self.query_len = 64
+ self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size))
+ self.query.data.normal_(mean=0.0, std=1.0)
+ self.qformer = Blip2QFormerModel(configuration)
+
+ self.linear = nn.Linear(configuration.hidden_size, self.llm_dim)
+ self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5)
+
+ def forward(self, x, atts):
+ query = self.query.expand(x.shape[0], -1, -1)
+
+ query_output = self.qformer(
+ query_embeds=query,
+ encoder_hidden_states=x,
+ encoder_attention_mask=atts,
+ return_dict=True,
+ )
+
+ query_proj = self.norm(self.linear(query_output.last_hidden_state))
+
+ return query_proj
--
Gitblit v1.9.1