| | |
| | | self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool |
| | | ) |
| | | all_heads[self.dims.n_text_layer // 2 :] = True |
| | | self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) |
| | | # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) |
| | | # alignment_heads_dense = model.get_buffer("alignment_heads").to_dense() |
| | | # model.register_buffer("alignment_heads", alignment_heads_dense, persistent=False) |
| | | |
| | | def set_alignment_heads(self, dump: bytes): |
| | | array = np.frombuffer( |