嘉渊
2023-04-25 7f7c23c36fab81f1f64e0ecdeb22875960b7200f
funasr/bin/train.py
@@ -242,6 +242,12 @@
        help="Enable resuming if checkpoint is existing",
    )
    # Register the --train_dtype option. Only the three listed names are
    # accepted and "float32" is the default. Elsewhere in this file the chosen
    # name is looked up on the torch module (getattr(torch, args.train_dtype))
    # when the model is converted.
    parser.add_argument(
        "--train_dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type for training.",
    )
    parser.add_argument(
        "--use_amp",
        type=str2bool,
        default=False,
@@ -253,6 +259,12 @@
        help="Show the logs every the number iterations in each epochs at the "
             "training phase. If None is given, it is decided according the number "
             "of training samples automatically .",
    )
    # Register the --use_tensorboard switch. The command-line value is passed
    # through str2bool (defined outside this excerpt) and the option defaults
    # to True, so tensorboard logging is on unless explicitly disabled.
    parser.add_argument(
        "--use_tensorboard",
        type=str2bool,
        default=True,
        help="Enable tensorboard logging",
    )
    # pretrained model related
@@ -508,6 +520,10 @@
    # prepare_data / build_model / build_optimizer / build_scheduler are
    # project helpers defined outside this excerpt; only their call order is
    # visible here.
    prepare_data(args, distributed_option)
    model = build_model(args)
    # Convert the model in one .to() call: the dtype is resolved by name from
    # the torch module using args.train_dtype, and the device is "cuda" when
    # args.ngpu is positive, otherwise "cpu".
    model = model.to(
        dtype=getattr(torch, args.train_dtype),
        device="cuda" if args.ngpu > 0 else "cpu",
    )
    # The optimizers are built from the already-converted model, and the
    # schedulers are built from those optimizers.
    optimizers = build_optimizer(args, model=model)
    schedulers = build_scheduler(args, optimizers)
@@ -515,6 +531,7 @@
                                                                   distributed_option.dist_rank,
                                                                   distributed_option.local_rank))
    # Emit run diagnostics at INFO level: the value returned by
    # pytorch_cudnn_version(), the full args object, the model_summary(model)
    # output, and the optimizer and scheduler objects formatted via str.format.
    logging.info(pytorch_cudnn_version())
    logging.info("Args: {}".format(args))
    logging.info(model_summary(model))
    logging.info("Optimizer: {}".format(optimizers))
    logging.info("Scheduler: {}".format(schedulers))