#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Multi-Head Attention layer definition."""

import math

import numpy
import torch
from torch import nn
import torch.nn.functional as F
from typing import Optional, Tuple

from funasr.models.sanm.attention import MultiHeadedAttentionSANM


class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
    """SANM multi-head attention taking separate masks for the FSMN memory
    branch (mask[0]) and the attention scores (mask[1])."""

    def forward(self, x, mask, mask_att_chunk_encoder=None):
        """Compute scaled dot-product attention combined with FSMN memory."""
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask[0], mask_att_chunk_encoder)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
        return att_outs + fsmn_memory
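

# A minimal usage sketch (illustrative, not part of the original module). It
# assumes the constructor signature inherited from MultiHeadedAttentionSANM,
# (n_head, in_feat, n_feat, dropout_rate, kernel_size), and that `mask` is a
# pair where mask[0] gates the FSMN memory branch and mask[1] masks the
# attention scores; verify both against funasr.models.sanm.attention.
if __name__ == "__main__":
    batch, time, n_feat, n_head = 2, 16, 256, 4
    attn = MultiHeadedAttentionSANMwithMask(n_head, n_feat, n_feat, 0.1, 11)
    x = torch.randn(batch, time, n_feat)      # (batch, time, feature)
    pad_mask = torch.ones(batch, 1, time)     # 1 = valid frame, 0 = padding
    out = attn(x, (pad_mask, pad_mask))
    print(out.shape)                          # expected: (batch, time, n_feat)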