# Source: funasr/export/models/modules/multihead_att.py, diff hunk @@ -75,6 +75,8 @@
# (in the hunk, this code follows the fragment "return x, cache", the tail of a
# definition outside this view).


def torch_major_minor(version):
    """Return ``(major, minor)`` as a tuple of ints parsed from a version string.

    Only the leading digits of each of the first two dot-separated components
    are used, so suffixes like ``"+cu113"`` or ``"0a0"`` do not break parsing.
    A missing or non-numeric component is treated as ``0``.

    Args:
        version: A version string such as ``torch.__version__``.

    Returns:
        A 2-tuple of ints, e.g. ``"1.10.2+cu113" -> (1, 10)``.
    """
    import re

    numbers = []
    for piece in version.split(".")[:2]:
        match = re.match(r"\d+", piece)
        numbers.append(int(match.group()) if match else 0)
    # Pad so single-component versions (e.g. "2") still compare sensibly.
    while len(numbers) < 2:
        numbers.append(0)
    return tuple(numbers)


# Compare versions as integer tuples. The previous float parse turned "1.10"
# into 1.1, so torch 1.10-1.13 wrongly compared as older than 1.8 and the
# fx.wrap registration below was skipped.
torch_version = torch_major_minor(torch.__version__)
if torch_version >= (1, 8):
    import torch.fx

    # Must be called at module scope: torch.fx.wrap registers the named
    # global function as a leaf so symbolic tracing calls it instead of
    # tracing through it (per the torch.fx documentation).
    torch.fx.wrap('preprocess_for_attn')