
        super().__init__()

        assert whisper_model in whisper.available_models()
        _model = whisper.load_model(
            whisper_model, download_root=download_dir, device="cpu"
        )
        # deep-copy only the text decoder; the full Whisper model
        # (encoder included) is not kept around after __init__
        self.decoders = copy.deepcopy(_model.decoder)
        attention_dim = self.decoders.token_embedding.embedding_dim
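        # NOTE: attention_dim equals the loaded checkpoint's decoder width
        # (e.g., 768 for the "small" Whisper model).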

            olens: (batch, )
        """
        tgt, memory = ys_in_pad, hs_pad
        # token embedding plus learned positional embedding, truncated to the
        # current target length
        tgt = (
            self.decoders.token_embedding(tgt)
            + self.decoders.positional_embedding[: tgt.size(1)]
        )
        tgt = self.dropout(tgt)

        x = tgt.to(memory.dtype)

        # memory_mask left as None: no cross-attention padding mask is applied
        memory_mask = None

        for layer, block in enumerate(self.decoders.blocks):
            # run each decoder block with Whisper's causal self-attention mask
            x = block(
                x,
                memory,
                mask=self.decoders.mask,
                memory_mask=memory_mask,
                is_pad_mask=False,
                is_pad_memory_mask=True,
            )

            # Whisper itself uses no dropout; it is applied between blocks
            # (but not after the last one) as fine-tuning regularization
            if layer < len(self.decoders.blocks) - 1:
                x = self.dropout(x)

        x = self.decoders.ln(x)
        # tied embeddings: project onto the transposed token-embedding matrix
        # to obtain vocabulary logits
        x = (
            x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return x, ys_in_lens
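
    # Usage sketch for forward() (shapes and argument names are assumptions,
    # not from the original source):
    #   hs_pad:    (batch, T_enc, n_state)  encoder output
    #   ys_in_pad: (batch, L)               target token ids
    #   logits, olens = forward(hs_pad, hlens, ys_in_pad, ys_in_lens)
    #   logits:    (batch, L, n_vocab) pre-softmax scores; olens == ys_in_lens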

            cache implementation is ignored for now
            for simplicity & correctness
        """
        x = (
            self.decoders.token_embedding(tgt)
            + self.decoders.positional_embedding[: tgt.size(1)]
        )
        x = self.dropout(x)
        x = x.to(memory.dtype)


        x = self.decoders.ln(x)
        # score only the newest position; tied-embedding projection as above
        y = x[:, -1]
        y = (
            y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()
        # log-probabilities for the scorer interface; the second return value
        # is the (unused) cache slot
        y = torch.log_softmax(y, dim=-1)

        return y, None
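

# Minimal greedy-decoding sketch built on forward_one_step. Illustrative only:
# `decoder`, `encoder_out`, `sot_id`, `eot_id`, and `max_decode_len` are
# assumed names, and the (tgt, tgt_mask, memory) calling convention is
# inferred from the method body above, not confirmed by the source.
#
#     memory = encoder_out                      # (1, T_enc, n_state)
#     ys = torch.tensor([[sot_id]])             # running hypothesis
#     for _ in range(max_decode_len):
#         logp, _ = decoder.forward_one_step(ys, None, memory)
#         next_tok = logp.argmax(dim=-1, keepdim=True)   # (1, 1)
#         ys = torch.cat([ys, next_tok], dim=1)
#         if next_tok.item() == eot_id:
#             break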