From b512846c2ca0cb0e28b1cea6c9980b2d04e1d7ae Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 29 四月 2024 10:51:20 +0800
Subject: [PATCH] batch

---
 funasr/models/llm_asr/adaptor.py |   19 +++++++++++--------
 1 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py
index 2093588..8c2a804 100644
--- a/funasr/models/llm_asr/adaptor.py
+++ b/funasr/models/llm_asr/adaptor.py
@@ -3,6 +3,7 @@
 
 from funasr.register import tables
 
+
 @tables.register("adaptor_classes", "Linear")
 class Linear(nn.Module):
     def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
@@ -20,13 +21,14 @@
         if num_frames_to_discard > 0:
             x = x[:, :-num_frames_to_discard, :]
         seq_len = x.size(1)
-        
+
         x = x.contiguous()
         x = x.view(batch_size, seq_len // self.k, dim * self.k)
         x = self.linear1(x)
         x = self.relu(x)
         x = self.linear2(x)
         return x
+
 
 @tables.register("adaptor_classes", "QFormer")
 class EncoderProjectorQFormer(nn.Module):
@@ -35,28 +37,29 @@
         self.encoder_dim = encoder_dim
         self.llm_dim = llm_dim
         from transformers import Blip2QFormerConfig, Blip2QFormerModel
+
         configuration = Blip2QFormerConfig()
         configuration.encoder_hidden_size = self.encoder_dim
         configuration.num_hidden_layers = 2
-        
+
         self.query_len = 64
         self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size))
         self.query.data.normal_(mean=0.0, std=1.0)
         self.qformer = Blip2QFormerModel(configuration)
-        
+
         self.linear = nn.Linear(configuration.hidden_size, self.llm_dim)
         self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5)
-    
+
     def forward(self, x, atts):
         query = self.query.expand(x.shape[0], -1, -1)
-        
+
         query_output = self.qformer(
             query_embeds=query,
             encoder_hidden_states=x,
             encoder_attention_mask=atts,
             return_dict=True,
         )
-        
+
         query_proj = self.norm(self.linear(query_output.last_hidden_state))
-        
-        return query_proj
\ No newline at end of file
+
+        return query_proj

--
Gitblit v1.9.1