| | |
| | | |
| | | return x |
| | | |
def init_state(self, x):
    """Create the initial scorer state.

    Args:
        x: Accepted for interface compatibility; not read by this method.

    Returns:
        dict: A new, empty dictionary (a distinct object on every call).

    """
    return {}
| | | |
def final_score(self, state) -> float:
    """Return the end-of-sequence score (optional hook).

    This implementation contributes nothing at the end of a hypothesis:
    it ignores ``state`` and always yields zero.

    Args:
        state: Scorer state for prefix tokens (unused here).

    Returns:
        float: Always ``0.0``.

    """
    eos_score = 0.0
    return eos_score
| | | |
def score(self, ys, state, x):
    """Score the next token for a single hypothesis.

    Adds a leading axis to ``ys`` and ``x`` with ``unsqueeze(0)``, runs
    ``self.forward`` with ``state`` passed as ``cache``, then removes the
    leading axis again and keeps only the last row of the result.

    Args:
        ys: Prefix tokens of one hypothesis (presumably a 1-D tensor of
            token ids -- TODO confirm against the caller).
        state: Scorer state for ``ys``; forwarded to ``self.forward`` as
            ``cache`` and returned unchanged by this method.
        x: Encoder output for the utterance being decoded (presumably
            without a batch axis -- TODO confirm).

    Returns:
        tuple: ``(scores, state)`` where ``scores`` is the last row of
        the forward output after ``squeeze(0)`` and ``state`` is the
        same object that was passed in.

    """
    # NOTE(review): no causal mask is built here because ``self.forward``
    # is called without one; building it would only allocate an unused
    # tensor on every decoding step. If ``forward`` is expected to
    # receive a mask, that needs to be wired in explicitly.
    logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
    return logp.squeeze(0)[-1, :], state
| | | |
| | | |
| | | class MultiHeadedAttentionSANMDecoder(nn.Module): |
| | | """Multi-Head Attention layer. |