| | |
    def __init__(self, *args, **kwargs):
        """Initialize by delegating all arguments unchanged to the parent class.

        This subclass adds no state of its own; it only overrides ``forward``.
        """
        super().__init__(*args, **kwargs)
| | | |
| | | def forward(self, x, mask, mask_shift_chunk=None, mask_att_chunk_encoder=None): |
| | | def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None): |
| | | q_h, k_h, v_h, v = self.forward_qkv(x) |
| | | fsmn_memory = self.forward_fsmn(v, mask[0], mask_shift_chunk) |
| | | fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk) |
| | | q_h = q_h * self.d_k ** (-0.5) |
| | | scores = torch.matmul(q_h, k_h.transpose(-2, -1)) |
| | | att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder) |