| | |
| | | |
| | | return x |
| | | |
def init_state(self, x):
    """Create the initial scorer state.

    Args:
        x: Not used by this implementation.

    Returns:
        dict: A new, empty dictionary.

    """
    return {}
| | | |
def final_score(self, state) -> float:
    """Return the score applied at end of sequence (optional hook).

    This implementation contributes nothing: it ignores ``state`` and
    always returns zero.

    Args:
        state: Scorer state for prefix tokens (unused here).

    Returns:
        float: The final score, always ``0.0``.

    """
    eos_score = 0.0
    return eos_score
| | | |
def score(self, ys, state, x):
    """Score the last position of a single (unbatched) prefix.

    ``ys`` and ``x`` each get a leading dimension of size 1 added before
    being passed to ``self.forward``, with ``state`` supplied as ``cache``.

    Args:
        ys: Prefix tokens; must support ``unsqueeze(0)``.
        state: Scorer state, forwarded to ``self.forward`` as ``cache``.
        x: Encoder-side input for this prefix; must support ``unsqueeze(0)``.

    Returns:
        tuple: ``(scores, state)`` where ``scores`` is the last row of the
        forward output after its leading dimension is squeezed away, and
        ``state`` is the same object that was passed in.

    """
    # NOTE(review): assumes forward returns (1, length, vocab) — TODO confirm.
    # The causal mask that used to be built here was never passed anywhere,
    # so it has been dropped; forward receives exactly the same arguments.
    logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
    return logp.squeeze(0)[-1, :], state
| | | |
| | | |
| | | class MultiHeadedAttentionSANMDecoder(nn.Module): |
| | | """Multi-Head Attention layer. |
| | |
| | | is_pad_mask = kwargs.get("is_pad_mask", False) |
| | | is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False) |
| | | |
| | | fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 else None |
| | | fsmn_cache = cache[layer]["fsmn_cache"] if cache is not None and len(cache) > 0 else None |
| | | # if fsmn_cache is not None: |
| | | # x = x[:, -1:] |
| | | att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache) |