| | |
| | | self.register_buffer("mask", mask, persistent=False) |
| | | |
| | | self.use_padmask = kwargs.get("use_padmask", True) |
| | | # def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): |
| | | # """ |
| | | # x : torch.LongTensor, shape = (batch_size, <= n_ctx) |
| | | # the text tokens |
| | | # xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) |
| | | # the encoded audio features to be attended on |
| | | # """ |
| | | # offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 |
| | | # x = ( |
| | | # self.token_embedding(x) |
| | | # + self.positional_embedding[offset: offset + x.shape[-1]] |
| | | # ) |
| | | # x = x.to(xa.dtype) |
| | | # |
| | | # for block in self.blocks: |
| | | # x = block(x, xa, mask=self.mask, kv_cache=kv_cache) |
| | | # |
| | | # x = self.ln(x) |
| | | # logits = ( |
| | | # x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) |
| | | # ).float() |
| | | # |
| | | # return logits |
| | | |
| | | |
| | | |
| | | def forward( |