嘉渊
2023-04-28 f97e0eb9eee3f14c410ce905b73d0c83033dc1c9
funasr/main_funcs/calculate_all_attentions.py
@@ -21,12 +21,12 @@
from funasr.modules.attention import MultiHeadedAttention
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
@torch.no_grad()
def calculate_all_attentions(
    model: AbsESPnetModel, batch: Dict[str, torch.Tensor]
    model: FunASRModel, batch: Dict[str, torch.Tensor]
) -> Dict[str, List[torch.Tensor]]:
    """Derive the outputs from the all attention layers