        return self.linear_out(context_layer)  # (batch, time1, d_model)


-def preprocess_for_attn(x, mask, cache, pad_fn):
+def preprocess_for_attn(x, mask, cache, pad_fn, kernel_size):
    x = x * mask
    x = x.transpose(1, 2)
    if cache is None:
        x = pad_fn(x)
    else:
-        x = torch.cat((cache[:, :, 1:], x), dim=2)
-    cache = x
+        x = torch.cat((cache, x), dim=2)
+    cache = x[:, :, -(kernel_size-1):]
    return x, cache
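A minimal sketch of how the updated cache behaves across consecutive streaming chunks, calling the new preprocess_for_attn above. The kernel_size, chunk length, tensor sizes, and pad_fn (a causal left-pad built with nn.ConstantPad1d) are illustrative assumptions, not values taken from the model config:

import torch
import torch.nn as nn

# Illustrative values only (assumed, not taken from the model config).
kernel_size = 11
pad_fn = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)  # causal left-padding (assumed)

batch, d_model, chunk_len = 1, 4, 6
cache = None
for step in range(3):  # three consecutive streaming chunks
    x = torch.randn(batch, chunk_len, d_model)
    mask = torch.ones(batch, chunk_len, 1)
    x, cache = preprocess_for_attn(x, mask, cache, pad_fn, kernel_size)
    # x carries kernel_size - 1 frames of left context for the causal conv;
    # the cache is truncated to the last kernel_size - 1 frames on every chunk.
    print(step, tuple(x.shape), tuple(cache.shape))  # (1, 4, 16) (1, 4, 10) each step

With this scheme the cache stays at a fixed kernel_size - 1 frames no matter how many chunks have been processed, which is what the added kernel_size argument is for.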
        self.attn = None

    def forward(self, inputs, mask, cache=None):
-        x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn)
+        x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn, self.kernel_size)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)