
class TargetDelayTransformer(AbsPunctuation):

    def __init__(
        self,
        vocab_size: int,
        punc_size: int,
        att_unit: int = 256,
        # NOTE: the remaining configuration arguments of the original __init__ are
        # not shown here; punc_size and att_unit are assumed parameters because the
        # output projection below needs them.
    ):
        super().__init__()
        # Project encoder hidden states to punctuation-class logits.
        self.decoder = nn.Linear(att_unit, punc_size)

    def _target_mask(self, ys_in_pad):
        # Padding mask: True for real tokens, False for the padding id 0.
        ys_mask = ys_in_pad != 0
        # subsequent_n_mask is presumably a causal-mask variant parameterized by
        # n=5, which realizes the model's target delay.
        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m
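
    # Illustrative note, assuming subsequent_n_mask(size, n) returns a (size, size)
    # boolean mask: _target_mask combines it with the padding mask by broadcasted
    # logical AND, so padded positions can never be attended to. Identifiers below
    # are hypothetical:
    #
    #   ys_in_pad = torch.tensor([[7, 8, 9, 0]])    # one sequence padded to length 4
    #   mask = model._target_mask(ys_in_pad)        # shape (1, 4, 4); column 3 all False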

    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute punctuation logits for padded token id sequences."""
        # Assumed reconstruction of the missing body: embed -> delayed-mask
        # self-attention encoder -> linear decoder.
        x = self.embed(input)
        h, _ = self.encoder(x, self._target_mask(input))
        y = self.decoder(h)
        return y, None
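
    # Illustrative sketch: during training the logits returned by forward() would
    # typically be compared against per-token punctuation labels with a
    # cross-entropy loss. Identifiers below are hypothetical:
    #
    #   logits, _ = model(tokens, token_lengths)    # (batch, length, punc_size)
    #   loss = torch.nn.functional.cross_entropy(
    #       logits.transpose(1, 2),                 # (batch, punc_size, length)
    #       punc_labels,                            # (batch, length)
    #       ignore_index=0,                         # assumed padding label id
    #   )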

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score a new token for a single hypothesis.

        Args:
            y: Prefix token ids of shape (length,).
            state: Cached decoder states from the previous step, or None.
            x: Encoder feature that conditions this hypothesis.

        Returns:
            Log-probabilities for the next token and the updated cache.
        """
        # Add a batch dimension, run one incremental encoder step with the cached
        # per-layer states, and score only the newest position.
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(
            self.embed(y), self._target_mask(y), cache=state
        )
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache
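
    # Usage sketch: score() matches the ESPnet-style ScorerInterface convention
    # (prefix tokens, cached state, conditioning feature), so an incremental greedy
    # loop could look like this. Identifiers are hypothetical:
    #
    #   state = None
    #   ys = torch.tensor([sos_id], dtype=torch.long)
    #   for _ in range(max_len):
    #       logp, state = model.score(ys, state, enc_feat)
    #       ys = torch.cat([ys, logp.argmax(dim=-1, keepdim=True)])
    #
    # batch_score() below performs the same incremental step for several hypotheses
    # at once by stacking their per-layer caches.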

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new tokens for a batch of hypotheses.

        Args:
            ys: Prefix token ids for each hypothesis, shape (n_batch, length).
            states: Per-hypothesis caches from the previous step.
            xs: Encoder features that condition each hypothesis.

        Returns:
            Batched next-token log-probabilities and the updated per-hypothesis caches.
        """
        # Assumed setup (missing in the excerpt): number of hypotheses and number of
        # encoder layers, used to regroup the per-hypothesis caches.
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        # batch decoding
        h, _, states = self.encoder.forward_one_step(
            self.embed(ys), self._target_mask(ys), cache=batch_state
        )
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)