        return self.linear_out(context_layer)  # (batch, time1, d_model)


def preprocess_for_attn(x, mask, cache, pad_fn):
    # Apply the padding mask and move features to the channel dim for the
    # FSMN Conv1d: (batch, time, d) -> (batch, d, time).
    x = x * mask
    x = x.transpose(1, 2)
    if cache is None:
        # First chunk: pad the time axis so the convolution keeps its length.
        x = pad_fn(x)
    else:
        # Streaming: drop the oldest cached frame and append the new frames.
        x = torch.cat((cache[:, :, 1:], x), dim=2)
    cache = x
    return x, cache


import torch.fx

# Register the helper as a leaf for symbolic tracing: the tracer records a
# single call to preprocess_for_attn instead of tracing through its body, so
# the data-dependent `if cache is None` branch stays a run-time check instead
# of being specialized away at trace time.
torch.fx.wrap('preprocess_for_attn')

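As a quick sanity check of the two branches, the snippet below runs the wrapped helper once without and once with a cache. It is a standalone sketch: the pad widths and tensor sizes are made up for illustration and are not taken from the model configuration.

```python
import torch
import torch.nn as nn

pad = nn.ConstantPad1d((4, 0), 0.0)   # e.g. a causal pad for kernel_size = 5
x = torch.randn(1, 10, 8)             # (batch, time, d_model)
m = torch.ones(1, 10, 1)              # (batch, time, 1) padding mask

# No cache: the time axis is padded for the FSMN convolution.
y, cache = preprocess_for_attn(x, m, None, pad)
print(y.shape)                        # torch.Size([1, 8, 14])

# With a cache: the oldest cached frame is dropped, the new chunk appended.
y2, cache = preprocess_for_attn(x, m, cache, pad)
print(y2.shape)                       # torch.Size([1, 8, 23])
```
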
class MultiHeadedAttentionSANMDecoder(nn.Module):
    def __init__(self, model):
        super().__init__()
        # Reuse the FSMN depthwise convolution and its padding layer from the
        # original module being exported.
        self.fsmn_block = model.fsmn_block
        self.pad_fn = model.pad_fn
        self.attn = None

    def forward(self, inputs, mask, cache=None):
        # b, t, d = inputs.size()
        # mask = torch.reshape(mask, (b, -1, 1))
        # Masking, the transpose and the cache branch all live in the wrapped
        # helper, so the data-dependent `if` never reaches the tracer.
        x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)

        # Residual connection and re-apply the padding mask before returning
        # the output together with the updated cache.
        x = x + inputs
        x = x * mask
        return x, cache

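Since the listing does not show how the wrapper is built or traced, here is a minimal, self-contained sketch. Everything in it is illustrative rather than part of the export code: `_FsmnStub` is a hypothetical stand-in for the FunASR module that supplies `fsmn_block` and `pad_fn`, the sizes are made up, and it assumes a recent PyTorch where `torch.fx` can unpack the result of a wrapped call.

```python
import torch
import torch.fx
import torch.nn as nn


class _FsmnStub(nn.Module):
    # Hypothetical stand-in for the original module: it only provides the two
    # attributes that the export wrapper copies in __init__.
    def __init__(self, n_feat=512, kernel_size=11):
        super().__init__()
        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size,
                                    groups=n_feat, bias=False)
        self.pad_fn = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)


export_attn = MultiHeadedAttentionSANMDecoder(_FsmnStub())
traced = torch.fx.symbolic_trace(export_attn)
print(traced.graph)   # preprocess_for_attn appears as a single call_function node

x = torch.randn(1, 10, 512)           # (batch, time, d_model)
mask = torch.ones(1, 10, 1)           # (batch, time, 1) padding mask
out, cache = traced(x, mask, None)
print(out.shape)                      # torch.Size([1, 10, 512])
```

Because the helper is a leaf, the generated code still calls `preprocess_for_attn` at run time, so the cache branch is preserved in the traced module rather than being specialized to the `cache=None` path seen during tracing.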