funasr/main_funcs/calculate_all_attentions.py
@@ -21,12 +21,12 @@ from funasr.modules.attention import MultiHeadedAttention
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel


 @torch.no_grad()
 def calculate_all_attentions(
-    model: AbsESPnetModel, batch: Dict[str, torch.Tensor]
+    model: FunASRModel, batch: Dict[str, torch.Tensor]
 ) -> Dict[str, List[torch.Tensor]]:
     """Derive the outputs from the all attention layers