| | |
| | | |
| | | |
| | | class TargetDelayTransformer(AbsPunctuation): |
| | | |
| | | def __init__( |
| | | self, |
| | | vocab_size: int, |
| | |
| | | ): |
| | | super().__init__() |
| | | if pos_enc == "sinusoidal": |
| | | # pos_enc_class = PositionalEncoding |
| | | # pos_enc_class = PositionalEncoding |
| | | pos_enc_class = SinusoidalPositionEncoder |
| | | elif pos_enc is None: |
| | | |
| | |
| | | num_blocks=layer, |
| | | dropout_rate=dropout_rate, |
| | | input_layer="pe", |
| | | # pos_enc_class=pos_enc_class, |
| | | # pos_enc_class=pos_enc_class, |
| | | padding_idx=0, |
| | | ) |
| | | self.decoder = nn.Linear(att_unit, punc_size) |
| | | |
| | | |
| | | # def _target_mask(self, ys_in_pad): |
| | | # ys_mask = ys_in_pad != 0 |
| | | # m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0) |
| | | # return ys_mask.unsqueeze(-2) & m |
| | | |
| | | |
def forward(
    self, input: torch.Tensor, text_lengths: torch.Tensor
) -> Tuple[torch.Tensor, None]:
    """Run the embed -> encoder -> decoder stack over a batch of token ids.

    No loss is computed here; the decoder output is returned as-is.

    Args:
        input: Token ids handed to ``self.embed``. Presumably shaped
            (batch, length) -- TODO confirm against the caller.
        text_lengths: Forwarded unchanged to ``self.encoder`` as its second
            positional argument; presumably the unpadded length of each
            sequence -- TODO confirm.

    Returns:
        A ``(y, None)`` pair where ``y`` is ``self.decoder`` applied to the
        first value returned by the encoder. The second slot is always
        ``None``.
    """
    embedded = self.embed(input)
    # The encoder returns three values; only the first one is used here.
    hidden, _, _ = self.encoder(embedded, text_lengths)
    return self.decoder(hidden), None
| | | |
def score(
    self, y: torch.Tensor, state: Any, x: torch.Tensor
) -> Tuple[torch.Tensor, Any]:
    """Score new token.

    Args:
        y: Token-id prefix. A leading axis of size 1 is added before it is
            embedded, so it is presumably 1-D -- TODO confirm.
        state: Passed to ``self.encoder.forward_one_step`` as ``cache``.
        x: Not used by this method.

    Returns:
        A ``(logp, cache)`` pair: ``logp`` is the log-softmax (over the last
        axis) of the decoder output at the final position, with the leading
        axis removed again; ``cache`` is the third value returned by
        ``forward_one_step``.
    """
    ys = y.unsqueeze(0)
    # NOTE(review): the only `_target_mask` definition visible in this class
    # is commented out -- confirm one is inherited, otherwise this call
    # raises AttributeError.
    mask = self._target_mask(ys)
    # A single encoder step; the original text carried this call twice.
    h, _, cache = self.encoder.forward_one_step(self.embed(ys), mask, cache=state)
    # Only the last position is scored.
    logits = self.decoder(h[:, -1])
    logp = logits.log_softmax(dim=-1).squeeze(0)
    return logp, cache
| | | |
| | | def batch_score( |
| | | self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, List[Any]]: |
| | | def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]: |
| | | """Score new token batch. |
| | | |
| | | Args: |
| | |
| | | batch_state = None |
| | | else: |
| | | # transpose state of [batch, layer] into [layer, batch] |
| | | batch_state = [ |
| | | torch.stack([states[b][i] for b in range(n_batch)]) |
| | | for i in range(n_layers) |
| | | ] |
| | | batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)] |
| | | |
| | | # batch decoding |
| | | h, _, states = self.encoder.forward_one_step( |
| | | self.embed(ys), self._target_mask(ys), cache=batch_state |
| | | ) |
| | | h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state) |
| | | h = self.decoder(h[:, -1]) |
| | | logp = h.log_softmax(dim=-1) |
| | | |