funasr/bin/train.py
@@ -55,6 +55,8 @@ torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled) torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark) torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True) # open tf32 torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True) local_rank = int(os.environ.get('LOCAL_RANK', 0)) if local_rank == 0: