| | |
| | | ) |
| | | |
| | | mask = make_source_mask(x_len) |
| | | if self.unified_model_training: |
| | | x, mask = self.embed(x, mask, self.default_chunk_size) |
| | | else: |
| | | x, mask = self.embed(x, mask) |
| | | pos_enc = self.pos_enc(x) |
| | | |
| | | if self.unified_model_training: |
| | | chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item() |
| | | x, mask = self.embed(x, mask, chunk_size) |
| | | pos_enc = self.pos_enc(x) |
| | | chunk_mask = make_chunk_mask( |
| | | x.size(1), |
| | | chunk_size, |
| | |
| | | else: |
| | | chunk_size = (chunk_size % self.short_chunk_size) + 1 |
| | | |
| | | x, mask = self.embed(x, mask, chunk_size) |
| | | pos_enc = self.pos_enc(x) |
| | | |
| | | chunk_mask = make_chunk_mask( |
| | | x.size(1), |
| | | chunk_size, |
| | |
| | | device=x.device, |
| | | ) |
| | | else: |
| | | x, mask = self.embed(x, mask, None) |
| | | pos_enc = self.pos_enc(x) |
| | | chunk_mask = None |
| | | x = self.encoders( |
| | | x, |