雾聪
2023-06-28 54931dd4e1a099d7d6f144c4e12e5453deb3aa26
funasr/main_funcs/calculate_all_attentions.py
@@ -21,12 +21,12 @@
from funasr.modules.attention import MultiHeadedAttention
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
@torch.no_grad()
def calculate_all_attentions(
    model: AbsESPnetModel, batch: Dict[str, torch.Tensor]
    model: FunASRModel, batch: Dict[str, torch.Tensor]
) -> Dict[str, List[torch.Tensor]]:
    """Derive the outputs from the all attention layers