
        # mask is a pair: the 3-D (batch, time, dim) padding mask consumed by the
        # FSMN memory branch and the 4-D (batch, head, l, t) additive mask applied
        # to the attention scores
        mask_3d_btd, mask_4d_bhlt = mask
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
        q_h = q_h * self.d_k ** (-0.5)  # scale queries by 1 / sqrt(d_k)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
        return att_outs + fsmn_memory

        x = pad_fn(x)
    else:
        x = torch.cat((cache, x), dim=2)
    # keep the last kernel_size - 1 frames as left context for the next call
    cache = x[:, :, -(kernel_size - 1):]
    return x, cache
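

# Illustration (not part of the original module): the cache above keeps the last
# kernel_size - 1 frames so that, assuming pad_fn zero-pads on the left by
# kernel_size - 1, chunk-by-chunk processing of a causal Conv1d matches running
# the whole sequence at once. Minimal sketch with hypothetical sizes and a
# hypothetical helper name:
def _check_cache_equivalence(kernel_size: int = 3) -> None:
    torch.manual_seed(0)
    conv = nn.Conv1d(4, 4, kernel_size, groups=4, bias=False)
    full = torch.randn(1, 4, 10)
    left_pad = nn.ConstantPad1d((kernel_size - 1, 0), 0.0)
    ref = conv(left_pad(full))  # non-streaming reference
    cache = full.new_zeros(1, 4, kernel_size - 1)  # zero left context, like pad_fn
    outs = []
    for chunk in full.split(5, dim=2):  # two chunks of 5 frames
        x = torch.cat((cache, chunk), dim=2)
        cache = x[:, :, -(kernel_size - 1):]
        outs.append(conv(x))
    assert torch.allclose(ref, torch.cat(outs, dim=2), atol=1e-6)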


torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
if torch_version >= (1, 8):
    import torch.fx

    # treat preprocess_for_attn as a leaf during torch.fx symbolic tracing so
    # that its cache-dependent branching is not traced through
    torch.fx.wrap("preprocess_for_attn")
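
# Illustration (not part of the original module): torch.fx.wrap registers a
# function so that symbolic tracing records a single call_function node for it
# instead of tracing into its body, which is presumably why preprocess_for_attn
# is wrapped above: its cache-dependent Python branching cannot be traced
# symbolically. Hypothetical toy example of the same pattern (torch >= 1.8):
def _toy_leaf(t):
    # data-dependent Python branching that symbolic tracing cannot follow
    return t.relu() if bool(t.min() < 0) else t


if torch_version >= (1, 8):
    torch.fx.wrap("_toy_leaf")


def _fx_wrap_demo():
    class _Toy(nn.Module):
        def forward(self, x):
            return _toy_leaf(x) * 2

    # the printed graph contains a call_function node for _toy_leaf rather
    # than an inlined, trace-time-specialized branch
    print(torch.fx.symbolic_trace(_Toy()).graph)

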
class MultiHeadedAttentionSANMDecoderExport(nn.Module):

        self.linear_out = model.linear_out
        self.attn = None
        self.all_head_size = self.h * self.d_k


    def forward(self, query, key, value, mask):
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)
        return q, k, v


    def forward_attention(self, value, scores, mask):
        # the mask is additive: masked positions carry large negative values
        scores = scores + mask

        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

        # merge the heads: (batch, head, time1, d_k) -> (batch, time1, head * d_k)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        self.linear_pos = model.linear_pos
        # learnable position biases (u and v of the Transformer-XL formulation)
        self.pos_bias_u = model.pos_bias_u
        self.pos_bias_v = model.pos_bias_v


    def forward(self, query, key, value, pos_emb, mask):
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        p = self.transpose_for_scores(self.linear_pos(pos_emb))  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # matrix_ac is the content term (a)+(c), matrix_bd the position term
        # (b)+(d) of the Transformer-XL relative-attention decomposition
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)


        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)

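
# Illustration (not part of the original module): forward_attention adds the
# mask directly to the raw scores (scores = scores + mask) before the softmax,
# so masked key positions are expected to hold a large negative value and end
# up with ~zero attention weight. Tiny sketch with hypothetical values:
def _masked_softmax_demo():
    scores = torch.tensor([[[[0.5, 1.0, 2.0]]]])   # (batch, head, time1, time2)
    mask = torch.tensor([[[[0.0, 0.0, -1.0e4]]]])  # last key position is padding
    return torch.softmax(scores + mask, dim=-1)    # last weight is ~0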