querryton
2024-04-20 01df8f330ccc754223d5e2d688dc0a55d27f2dcc
funasr/bin/train.py
@@ -55,6 +55,8 @@
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
    # open tf32
    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
    
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    if local_rank == 0:
@@ -102,7 +104,7 @@
    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(model, device_ids=[local_rank],
                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", True))
    elif use_fsdp:
        # model = FSDP(model).cuda(local_rank)