zhifu gao
2024-04-17 eaf9dda9e4d970af3d09db695e9e10c83ef94e25
funasr/bin/train.py
@@ -102,7 +102,7 @@
    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(model, device_ids=[local_rank],
                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", True))
    elif use_fsdp:
        # model = FSDP(model).cuda(local_rank)