| | |
| | | x = self.relu(x) |
| | | x = self.linear2(x) |
| | | return x |
| | | |
@tables.register("adaptor_classes", "QFormer")
class EncoderProjectorQFormer(nn.Module):
    """Project encoder outputs into the LLM embedding space with a BLIP-2 Q-Former.

    A fixed set of learnable query vectors cross-attends to the encoder hidden
    states through a small ``Blip2QFormerModel``. The query outputs are then
    linearly projected to ``llm_dim`` and layer-normalized. The projected
    sequence therefore always has ``query_len`` positions, independent of the
    encoder output length.

    Args:
        downsample_rate: Accepted but not used by this projector (presumably
            kept so every registered adaptor shares one constructor signature).
        encoder_dim: Feature size of the encoder hidden states passed to
            ``forward``. It is used as the Q-Former's cross-attention input size.
        llm_dim: Output feature size (the LLM embedding size).
        ffn_dim: Accepted but not used by this projector.
        query_len: Number of learnable query tokens, i.e. the output sequence
            length. Defaults to 64.
        num_qformer_layers: Number of Q-Former transformer layers. Defaults to 2.
        **kwargs: Ignored. This lets extra config keys pass through harmlessly.

    Raises:
        ValueError: If ``query_len`` is not a positive integer.
    """

    def __init__(
        self,
        downsample_rate,
        encoder_dim,
        llm_dim,
        ffn_dim: int = 2048,
        query_len: int = 64,
        num_qformer_layers: int = 2,
        **kwargs,
    ):
        super().__init__()
        if query_len < 1:
            raise ValueError(f"query_len must be >= 1, got {query_len}")

        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim

        # Imported lazily so `transformers` is only required when this
        # adaptor is actually instantiated.
        from transformers import Blip2QFormerConfig, Blip2QFormerModel

        configuration = Blip2QFormerConfig()
        # The Q-Former cross-attends to the encoder outputs, so its
        # encoder-side width must match the encoder feature size.
        configuration.encoder_hidden_size = self.encoder_dim
        configuration.num_hidden_layers = num_qformer_layers

        self.query_len = query_len
        # Learnable queries with shape (1, query_len, hidden_size). They are
        # shared across the batch and expanded per call in forward().
        self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size))
        nn.init.normal_(self.query, mean=0.0, std=1.0)
        self.qformer = Blip2QFormerModel(configuration)

        # Map Q-Former hidden size -> LLM embedding size, then normalize.
        self.linear = nn.Linear(configuration.hidden_size, self.llm_dim)
        self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5)

    def forward(self, x, atts):
        """Compress encoder states into ``query_len`` LLM-sized embeddings.

        Args:
            x: Encoder hidden states. Assumed to be (batch, time, encoder_dim);
                TODO confirm with the caller. Dim 0 is used as the batch size.
            atts: Passed as ``encoder_attention_mask`` to the Q-Former.
                Presumably a (batch, time) padding mask with 1 for valid
                frames; verify against the caller.

        Returns:
            Tensor of shape (batch, query_len, llm_dim).
        """
        # Broadcast the shared learnable queries across the batch (no copy).
        query = self.query.expand(x.shape[0], -1, -1)

        query_output = self.qformer(
            query_embeds=query,
            encoder_hidden_states=x,
            encoder_attention_mask=atts,
            return_dict=True,
        )

        return self.norm(self.linear(query_output.last_hidden_state))