jmwang66
2023-05-16 6f7e27eb7c2d0a7649ec8f14d167c8da8e29f906
funasr/main_funcs/calculate_all_attentions.py
@@ -21,12 +21,12 @@
from funasr.modules.attention import MultiHeadedAttention
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
@torch.no_grad()
def calculate_all_attentions(
    model: AbsESPnetModel, batch: Dict[str, torch.Tensor]
    model: FunASRModel, batch: Dict[str, torch.Tensor]
) -> Dict[str, List[torch.Tensor]]:
    """Derive the outputs from the all attention layers