| | |
| | | att_outs = self.forward_attention(v_h, scores, None) |
| | | return att_outs + fsmn_memory, cache |
| | | |
class MultiHeadedAttentionSANMExport(nn.Module):
    """Export-oriented SANM attention: multi-head attention plus an FSMN branch.

    Shares (does not copy) the layers of an existing attention ``model``:
    ``linear_q_k_v``, ``linear_out``, ``fsmn_block``, ``pad_fn`` and its head
    configuration ``h`` / ``d_k``. The output is the multi-head scaled
    dot-product attention result plus an FSMN "memory" computed from the
    value projection.

    Args:
        model: Source attention module providing the attributes listed above.
    """

    def __init__(self, model):
        super().__init__()
        self.d_k = model.d_k
        self.h = model.h
        self.linear_out = model.linear_out
        self.linear_q_k_v = model.linear_q_k_v
        self.fsmn_block = model.fsmn_block
        self.pad_fn = model.pad_fn

        # Attention probabilities from the most recent forward pass.
        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, x, mask):
        """Compute attention output plus FSMN memory.

        Args:
            x: 3-D input tensor (batch, time, features) fed to ``linear_q_k_v``.
            mask: Pair ``(mask_3d_btd, mask_4d_bhlt)``. ``mask_3d_btd`` is
                multiplied into the value projection in the FSMN branch;
                ``mask_4d_bhlt`` is *added* to the attention scores before the
                softmax. Either element may be ``None`` to disable that mask.

        Returns:
            Tensor (batch, time, d_model): attention output + FSMN memory.
        """
        mask_3d_btd, mask_4d_bhlt = mask
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
        # Scale queries by 1/sqrt(d_k) before the dot product.
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
        return att_outs + fsmn_memory

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dim into (h, d_k) heads and move heads to dim 1.

        A 3-D (batch, time, h * d_k) tensor becomes (batch, h, time, d_k).
        """
        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward_qkv(self, x):
        """Project ``x`` to queries, keys and values.

        Returns:
            ``(q_h, k_h, v_h, v)``: per-head query/key/value tensors of shape
            (batch, h, time, d_k), plus the un-split value projection ``v``
            (batch, time, h * d_k) used by the FSMN branch.
        """
        q_k_v = self.linear_q_k_v(x)
        # The fused projection holds q, k and v back to back on the last dim.
        q, k, v = torch.split(q_k_v, self.all_head_size, dim=-1)
        q_h = self.transpose_for_scores(q)
        k_h = self.transpose_for_scores(k)
        v_h = self.transpose_for_scores(v)
        return q_h, k_h, v_h, v

    def forward_fsmn(self, inputs, mask):
        """Apply the FSMN block over time with a residual connection.

        Args:
            inputs: Value projection (batch, time, channels).
            mask: Multiplicative mask broadcastable to ``inputs``, or ``None``
                for no masking (equivalent to an all-ones mask).

        Returns:
            ``(fsmn_block(pad_fn(inputs)) + inputs) * mask``.
        """
        if mask is not None:
            inputs = inputs * mask
        # fsmn_block / pad_fn operate on (batch, channels, time).
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x = x + inputs
        if mask is not None:
            # Re-mask so padded frames stay zero after the convolution.
            x = x * mask
        return x

    def forward_attention(self, value, scores, mask):
        """Softmax-normalize ``scores`` and aggregate ``value`` across heads.

        Args:
            value: Per-head values (batch, h, time2, d_k).
            scores: Raw attention scores (batch, h, time1, time2).
            mask: Additive mask broadcastable to ``scores`` (large negative
                entries suppress positions), or ``None`` for no masking.

        Returns:
            Tensor (batch, time1, d_model) after ``linear_out``. The attention
            probabilities are also stored in ``self.attn``.
        """
        if mask is not None:
            scores = scores + mask

        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

        # Merge heads back: (batch, time1, head, d_k) -> (batch, time1, h * d_k).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)
| | | |
| | | |
| | | |
| | | class MultiHeadedAttentionSANMDecoder(nn.Module): |