funasr/bin/train.py
@@ -77,6 +77,12 @@
         help="Whether to use the find_unused_parameters in "
         "torch.nn.parallel.DistributedDataParallel ",
     )
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
+    )

     # cudnn related
     parser.add_argument(
@@ -399,6 +405,7 @@
     torch.backends.cudnn.deterministic = args.cudnn_deterministic

     # ddp init
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
     args.distributed = args.dist_world_size > 1
     distributed_option = build_distributed(args)