import math

import numpy
import torch
from torch import nn

from funasr.models.transformer.utils.nets_utils import make_pad_mask
import funasr.models.lora.layers as lora


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer."""

    def forward_attention(self, value, scores, mask):
        """Compute the attention context vector from value, scores and mask."""
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask):
        """Compute scaled dot-product attention."""
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
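
# Minimal usage sketch (illustrative only, not part of the module): it assumes the
# usual constructor MultiHeadedAttention(n_head, n_feat, dropout_rate) from the
# elided __init__ above, and uses make_pad_mask to build a padding mask in which
# non-zero entries mark valid frames.
#
#     mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
#     x = torch.randn(2, 50, 256)                   # (batch, time, n_feat)
#     mask = ~make_pad_mask([50, 30]).unsqueeze(1)  # (batch, 1, time); True = keep
#     out = mha(x, x, x, mask)                      # (batch, time, n_feat)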


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding."""

    def rel_shift(self, x):
        """Shift relative-position scores of shape (batch, head, time1, 2*time1-1)."""
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2

        if self.zero_triu:
            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(self, query, key, value, pos_emb, mask):
        """Compute scaled dot-product attention with relative positional encoding."""
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)  # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)  # (batch, head, time1, d_k)

        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)
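
# Shape sketch for the relative-position path above (illustrative only, assuming
# time1 == time2 == T as produced by a relative positional encoding module):
#
#     p          : (batch, head, 2*T - 1, d_k)   position embeddings after linear_pos
#     matrix_bd  : (batch, head, T, 2*T - 1)     q_with_bias_v @ p^T, indexed by offset
#     rel_shift  : (batch, head, T, T)           re-aligns the last axis so it indexes
#                                                key positions instead of relative offsets
#
# After the shift, matrix_ac + matrix_bd is a plain (batch, head, T, T) score matrix
# that feeds the shared forward_attention.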


class MultiHeadSelfAttention(nn.Module):
    """Multi-Head Self-Attention layer with a fused query/key/value projection."""

    def forward_qkv(self, x):
        """Project the input once and split it into per-head query, key and value."""
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)

        return q_h, k_h, v_h, v
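
    # Note (illustrative, not part of the module): linear_q_k_v projects to
    # 3 * h * d_k features, so torch.split with chunk size h * d_k yields the three
    # projections from a single matmul, e.g. for h=4, d_k=64 and input (2, 50, 256):
    #
    #     q_k_v = self.linear_q_k_v(x)                 # (2, 50, 768)
    #     q, k, v = torch.split(q_k_v, 256, dim=-1)    # three (2, 50, 256) tensors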

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute the attention context vector, optionally restricted by a chunk mask."""
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)

            min_value = float(
                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
            )
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_att_chunk_encoder=None):
        """Compute scaled dot-product self-attention over x."""
        q_h, k_h, v_h, v = self.forward_qkv(x)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs
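
# Minimal usage sketch (illustrative only): the constructor arguments below are an
# assumption based on the attributes used above (h heads of size d_k, a fused
# linear_q_k_v projection, dropout and linear_out); check the elided __init__ for
# the exact signature before relying on it.
#
#     mhsa = MultiHeadSelfAttention(n_head=4, n_feat=256, dropout_rate=0.1)
#     x = torch.randn(2, 50, 256)                   # (batch, time, feat)
#     mask = ~make_pad_mask([50, 30]).unsqueeze(1)  # (batch, 1, time); True = keep
#     y = mhsa(x, mask)                             # (batch, time, n_feat)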