old mode 100755
new mode 100644
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | | import sys |
| | | import torch |
| | | import torch.nn as nn |
| | | import hydra |
| | | import logging |
| | | import time |
| | | import argparse |
| | | from io import BytesIO |
| | | |
| | | import torch |
| | | from contextlib import nullcontext |
| | | import torch.distributed as dist |
| | | from collections.abc import Sequence |
| | | from omegaconf import DictConfig, OmegaConf |
| | | from torch.cuda.amp import autocast, GradScaler |
| | | from torch.nn.parallel import DistributedDataParallel as DDP |
| | | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| | | from torch.distributed.algorithms.join import Join |
| | | from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler |
| | | from funasr.train_utils.average_nbest_models import average_checkpoints |
| | | |
| | | from funasr.build_utils.build_args import build_args |
| | | from funasr.build_utils.build_dataloader import build_dataloader |
| | | from funasr.build_utils.build_distributed import build_distributed |
| | | from funasr.build_utils.build_model import build_model |
| | | from funasr.build_utils.build_optimizer import build_optimizer |
| | | from funasr.build_utils.build_scheduler import build_scheduler |
| | | from funasr.build_utils.build_trainer import build_trainer |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.torch_utils.model_summary import model_summary |
| | | from funasr.torch_utils.pytorch_version import pytorch_cudnn_version |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils.nested_dict_action import NestedDictAction |
| | | from funasr.utils.prepare_data import prepare_data |
| | | from funasr.utils.types import int_or_none |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump |
| | | from funasr.register import tables |
| | | from funasr.optimizers import optim_classes |
| | | from funasr.train_utils.trainer import Trainer |
| | | from funasr.schedulers import scheduler_classes |
| | | from funasr.train_utils.initialize import initialize |
| | | from funasr.download.download_from_hub import download_model |
| | | from funasr.models.lora.utils import mark_only_lora_as_trainable |
| | | from funasr.train_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.train_utils.load_pretrained_model import load_pretrained_model |
| | | from funasr.utils.misc import prepare_model_dir |
| | | from funasr import AutoModel |
| | | |
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    """Hydra entry point: resolve the model configuration, then start training.

    When the composed config has no ``model_conf`` the model is fetched from the
    model hub named by ``hub`` (default ``"ms"``) through ``download_model``,
    whose return value replaces ``kwargs`` before ``main`` is called.

    Args:
        kwargs: Hydra-composed training configuration; must contain ``model``.
            Setting ``debug`` drops into ``pdb`` before anything else runs.

    Raises:
        KeyError: if ``model`` is missing from the configuration.
    """
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if "model" not in kwargs:
        raise KeyError("the training config must define `model`")

    if "model_conf" not in kwargs:
        logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
        # Merge is_training into the mapping rather than passing it next to
        # **kwargs: a config that already defines is_training would otherwise
        # raise "got multiple values for keyword argument 'is_training'".
        download_kwargs = {**kwargs, "is_training": kwargs.get("is_training", True)}
        kwargs = download_model(**download_kwargs)

    main(**kwargs)
| | | |
def get_parser() -> argparse.ArgumentParser:
    """Build the legacy argparse parser for FunASR training options.

    Options are grouped as: common, distributed (DDP), cuDNN, trainer,
    pretrained-model, dataset, optimization, task/tokenizer and PAI/OSS.
    Within this file it is only called from the argparse-based ``__main__``
    block; the Hydra entry point does not use it.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="FunASR Common Training Parser",
    )

    # common configuration
    parser.add_argument("--output_dir", help="model save path")
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")

    # ddp related
    parser.add_argument(
        "--dist_backend",
        default="nccl",
        type=str,
        help="distributed backend",
    )
    parser.add_argument(
        "--dist_init_method",
        type=str,
        default="env://",
        help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
        '"WORLD_SIZE", and "RANK" are referred.',
    )
    parser.add_argument(
        "--dist_world_size",
        type=int,
        default=1,
        help="number of nodes for distributed training",
    )
    parser.add_argument(
        "--dist_rank",
        type=int,
        default=None,
        help="node rank for distributed training",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=None,
        help="local rank for distributed training",
    )
    parser.add_argument(
        "--dist_master_addr",
        default=None,
        type=str_or_none,
        help="The master address for distributed training. "
        "This value is used when dist_init_method == 'env://'",
    )
    parser.add_argument(
        "--dist_master_port",
        default=None,
        type=int_or_none,
        help="The master port for distributed training. "
        "This value is used when dist_init_method == 'env://'",
    )
    parser.add_argument(
        "--dist_launcher",
        default=None,
        type=str_or_none,
        choices=["slurm", "mpi", None],
        help="The launcher type for distributed training",
    )
    parser.add_argument(
        "--multiprocessing_distributed",
        default=True,
        type=str2bool,
        help="Use multi-processing distributed training to launch "
        "N processes per node, which has N GPUs. This is the "
        "fastest way to use PyTorch for either single node or "
        "multi node data parallel training",
    )
    parser.add_argument(
        "--unused_parameters",
        type=str2bool,
        default=False,
        help="Whether to use the find_unused_parameters in "
        "torch.nn.parallel.DistributedDataParallel ",
    )
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="local gpu id.",
    )

    # cudnn related
    parser.add_argument(
        "--cudnn_enabled",
        type=str2bool,
        default=torch.backends.cudnn.enabled,
        help="Enable CUDNN",
    )
    parser.add_argument(
        "--cudnn_benchmark",
        type=str2bool,
        default=torch.backends.cudnn.benchmark,
        help="Enable cudnn-benchmark mode",
    )
    parser.add_argument(
        "--cudnn_deterministic",
        type=str2bool,
        default=True,
        help="Enable cudnn-deterministic mode",
    )

    # trainer related
    parser.add_argument(
        "--max_epoch",
        type=int,
        default=40,
        help="The maximum number epoch to train",
    )
    parser.add_argument(
        "--max_update",
        type=int,
        default=sys.maxsize,
        help="The maximum number update step to train",
    )
    parser.add_argument(
        "--batch_interval",
        type=int,
        default=10000,
        help="The batch interval for saving model.",
    )
    parser.add_argument(
        "--patience",
        # Without a type the CLI value would stay a string ("3").
        type=int_or_none,
        default=None,
        help="Number of epochs to wait without improvement "
        "before stopping the training",
    )
    parser.add_argument(
        "--val_scheduler_criterion",
        type=str,
        nargs=2,
        default=("valid", "loss"),
        help="The criterion used for the value given to the lr scheduler. "
        'Give a pair referring the phase, "train" or "valid", '
        'and the criterion name. The mode specifying "min" or "max" can '
        "be changed by --scheduler_conf",
    )
    parser.add_argument(
        "--early_stopping_criterion",
        type=str,
        nargs=3,
        default=("valid", "loss", "min"),
        help="The criterion used for judging of early stopping. "
        'Give a triple: the phase, "train" or "valid", '
        'the criterion name and the mode, "min" or "max", e.g. "valid acc max".',
    )
    parser.add_argument(
        "--best_model_criterion",
        # Parse each "phase,name,mode" token into a triple so CLI values match
        # the shape of the default below.
        type=str2triple_str,
        nargs="+",
        default=[
            ("train", "loss", "min"),
            ("valid", "loss", "min"),
            ("train", "acc", "max"),
            ("valid", "acc", "max"),
        ],
        help="The criterion used for judging of the best model. "
        'Give one or more comma-separated triples of the phase, "train" or "valid", '
        'the criterion name, and the mode, "min" or "max", e.g. "valid,acc,max".',
    )
    parser.add_argument(
        "--keep_nbest_models",
        type=int,
        nargs="+",
        default=[10],
        help="Remove previous snapshots excluding the n-best scored epochs",
    )
    parser.add_argument(
        "--nbest_averaging_interval",
        type=int,
        default=0,
        help="The epoch interval to apply model averaging and save nbest models",
    )
    parser.add_argument(
        "--grad_clip",
        type=float,
        default=5.0,
        help="Gradient norm threshold to clip",
    )
    parser.add_argument(
        "--grad_clip_type",
        type=float,
        default=2.0,
        help="The type of the used p-norm for gradient clip. Can be inf",
    )
    parser.add_argument(
        "--grad_noise",
        type=str2bool,
        default=False,
        help="The flag to switch to use noise injection to "
        "gradients during training",
    )
    parser.add_argument(
        "--accum_grad",
        type=int,
        default=1,
        help="The number of gradient accumulation",
    )
    parser.add_argument(
        "--resume",
        type=str2bool,
        default=False,
        help="Enable resuming if checkpoint is existing",
    )
    parser.add_argument(
        "--use_amp",
        type=str2bool,
        default=False,
        help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
    )
    parser.add_argument(
        "--log_interval",
        # Without a type the CLI value would stay a string.
        type=int_or_none,
        default=None,
        help="Show the logs every the number iterations in each epochs at the "
        "training phase. If None is given, it is decided according the number "
        "of training samples automatically .",
    )

    # pretrained model related
    parser.add_argument(
        "--init_param",
        type=str,
        default=[],
        nargs="*",
        help="Specify the file path used for initialization of parameters. "
        "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
        "where file_path is the model file path, "
        "src_key specifies the key of model states to be used in the model file, "
        "dst_key specifies the attribute of the model to be initialized, "
        "and exclude_keys excludes keys of model states for the initialization. "
        "e.g.\n"
        "  # Load all parameters\n"
        "  --init_param some/where/model.pb\n"
        "  # Load only decoder parameters\n"
        "  --init_param some/where/model.pb:decoder:decoder\n"
        "  # Load only decoder parameters excluding decoder.embed\n"
        "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
    )
    parser.add_argument(
        "--ignore_init_mismatch",
        type=str2bool,
        default=False,
        help="Ignore size mismatch when loading pre-trained model",
    )
    parser.add_argument(
        "--freeze_param",
        type=str,
        default=[],
        nargs="*",
        help="Freeze parameters",
    )

    # dataset related
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="small",
        help="whether to use dataloader for large dataset",
    )
    parser.add_argument(
        "--dataset_conf",
        action=NestedDictAction,
        default=dict(),
        help="The keyword arguments for dataset",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="train_list for large dataset",
    )
    parser.add_argument(
        "--valid_data_file",
        type=str,
        default=None,
        help="valid_list for large dataset",
    )
    parser.add_argument(
        "--train_data_path_and_name_and_type",
        type=str2triple_str,
        action="append",
        default=[],
        help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ",
    )
    parser.add_argument(
        "--valid_data_path_and_name_and_type",
        type=str2triple_str,
        action="append",
        default=[],
    )
    parser.add_argument(
        "--train_shape_file",
        type=str,
        action="append",
        default=[],
    )
    parser.add_argument(
        "--valid_shape_file",
        type=str,
        action="append",
        default=[],
    )
    parser.add_argument(
        "--use_preprocessor",
        type=str2bool,
        default=True,
        help="Apply preprocessing to data or not",
    )

    # optimization related
    parser.add_argument(
        "--optim",
        type=lambda x: x.lower(),
        default="adam",
        help="The optimizer type",
    )
    parser.add_argument(
        "--optim_conf",
        action=NestedDictAction,
        default=dict(),
        help="The keyword arguments for optimizer",
    )
    parser.add_argument(
        "--scheduler",
        type=lambda x: str_or_none(x.lower()),
        default=None,
        help="The lr scheduler type",
    )
    parser.add_argument(
        "--scheduler_conf",
        action=NestedDictAction,
        default=dict(),
        help="The keyword arguments for lr scheduler",
    )

    # most task related
    parser.add_argument(
        "--init",
        type=lambda x: str_or_none(x.lower()),
        default=None,
        help="The initialization method",
        choices=[
            "chainer",
            "xavier_uniform",
            "xavier_normal",
            "kaiming_uniform",
            "kaiming_normal",
            None,
        ],
    )
    parser.add_argument(
        "--token_list",
        type=str_or_none,
        default=None,
        help="A text mapping int-id to token",
    )
    parser.add_argument(
        "--token_type",
        type=str,
        default="bpe",
        choices=["bpe", "char", "word"],
        help="",
    )
    parser.add_argument(
        "--bpemodel",
        type=str_or_none,
        default=None,
        help="The model file for sentencepiece",
    )
    parser.add_argument(
        "--cleaner",
        type=str_or_none,
        choices=[None, "tacotron", "jaconv", "vietnamese"],
        default=None,
        help="Apply text cleaning",
    )
    parser.add_argument(
        "--g2p",
        type=str_or_none,
        choices=g2p_choices,
        default=None,
        help="Specify g2p method if --token_type=phn",
    )

    # pai related
    parser.add_argument(
        "--use_pai",
        type=str2bool,
        default=False,
        help="flag to indicate whether training on PAI",
    )
    parser.add_argument(
        "--simple_ddp",
        type=str2bool,
        default=False,
    )
    parser.add_argument(
        "--num_worker_count",
        type=int,
        default=1,
        help="The number of machines on PAI.",
    )
    parser.add_argument(
        "--access_key_id",
        type=str,
        default=None,
        help="The username for oss.",
    )
    parser.add_argument(
        "--access_key_secret",
        type=str,
        default=None,
        help="The password for oss.",
    )
    parser.add_argument(
        "--endpoint",
        type=str,
        default=None,
        help="The endpoint for oss.",
    )
    parser.add_argument(
        "--bucket_name",
        type=str,
        default=None,
        help="The bucket name for oss.",
    )
    parser.add_argument(
        "--oss_bucket",
        default=None,
        help="oss bucket.",
    )

    return parser
| | | |
| | | |
if __name__ == '__main__':
    # Legacy argparse front end: parse the options declared by get_parser();
    # anything it does not recognise is handed, together with the parser,
    # to build_args(), which returns the final namespace.
    # NOTE(review): this runs before main_hydra() is invoked by the guard at the
    # bottom of the file, so Hydra-style overrides also land in the leftover
    # list — confirm build_args() tolerates them, or retire this block.
    parser = get_parser()
    args, extra_task_params = parser.parse_known_args()
    if len(extra_task_params) > 0:
        args = build_args(args, parser, extra_task_params)
def _parse_freeze_param(freeze_param):
    """Normalize the ``freeze_param`` option into a tuple of parameter-name prefixes.

    Accepted forms:
      * ``None`` or an empty string -> ``()``
      * a string holding a Python literal, e.g. ``"['encoder', 'decoder']"``
        or ``"'encoder'"``
      * a bare, optionally comma-separated string, e.g. ``"encoder"`` or
        ``"encoder,decoder.embed"``
      * any other iterable of names (list, tuple, OmegaConf ``ListConfig``)

    ``ast.literal_eval`` is used instead of ``eval`` so a config value can never
    execute arbitrary code.
    """
    import ast

    if freeze_param is None:
        return ()
    if not isinstance(freeze_param, str):
        return tuple(str(name) for name in freeze_param)
    try:
        parsed = ast.literal_eval(freeze_param)
    except (ValueError, SyntaxError, TypeError):
        # Bare module paths such as "encoder" or "decoder.embed" are not literals.
        return tuple(part.strip() for part in freeze_param.split(",") if part.strip())
    if isinstance(parsed, (list, tuple, set, frozenset)):
        return tuple(str(name) for name in parsed)
    return (str(parsed),)


def _custom_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    # Additional custom arguments
    min_num_params: int = int(1e8),
) -> bool:
    """FSDP auto-wrap policy: wrap large submodules whose params agree on requires_grad.

    FSDP calls the policy with ``recurse=True`` to ask whether to descend into
    ``module``'s children, and with ``recurse=False`` to ask whether to wrap it.
    ``nonwrapped_numel`` is the number of parameters not yet wrapped in the subtree.
    """
    if nonwrapped_numel < min_num_params:
        return False
    if recurse:
        # Keep descending even if this module mixes frozen and trainable
        # parameters: its children may still be uniform and worth wrapping.
        return True
    # Only wrap modules whose parameters are all frozen or all trainable.
    return len({p.requires_grad for p in module.parameters()}) == 1


def main(**kwargs):
    """Train a FunASR model described by the (Hydra-resolved) ``kwargs`` config.

    Steps:
      1. Seed RNGs and set cuDNN flags.
      2. Optionally initialise a process group: ``use_fsdp`` in the config
         selects FSDP; otherwise ``WORLD_SIZE > 1`` in the environment selects DDP.
      3. Build the model via ``AutoModel``, freeze ``freeze_param`` prefixes,
         then wrap/move the model.
      4. Build optimizer, scheduler and dataloaders from the config.
      5. Run the ``Trainer`` epoch loop (train, scheduler step, validate,
         checkpoint), then average checkpoints on rank 0.
    """
    import functools

    # reproducibility / cuDNN knobs
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank == 0:
        tables.print()

    # Check if we are using DDP or FSDP. FSDP takes precedence: a multi-process
    # FSDP run must not fall into the DDP branch below.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    use_fsdp = bool(kwargs.get("use_fsdp", False))
    use_ddp = world_size > 1 and not use_fsdp
    distributed = use_ddp or use_fsdp
    if distributed:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
        torch.cuda.set_device(local_rank)
    rank = dist.get_rank() if distributed else local_rank
    is_main_process = rank == 0

    # Only the main process logs at INFO; other ranks report errors only.
    # (basicConfig is a no-op once the root logger already has handlers.)
    logging.basicConfig(
        level="INFO" if is_main_process else "ERROR",
        format=f"[{os.uname()[1].split('.')[0]}]"
        f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    logging.info("Build model, frontend, tokenizer")
    device = kwargs.get("device", "cuda")
    # Build on CPU; the model is moved to / wrapped for its target device below.
    kwargs["device"] = "cpu"
    model = AutoModel(**kwargs)

    # save config.yaml (via prepare_model_dir) on the main process only
    if is_main_process:
        prepare_model_dir(**kwargs)

    # Continue with the configuration as resolved by AutoModel.
    kwargs = model.kwargs
    kwargs["device"] = device
    model = model.model
    kwargs.pop("model", None)
    train_conf = kwargs.get("train_conf") or {}

    # freeze_param: exact names or dotted prefixes of model parameters
    freeze_param = _parse_freeze_param(kwargs.get("freeze_param", None))
    if freeze_param:
        logging.info("freeze_param is not None: %s", freeze_param)
        for name, param in model.named_parameters():
            if any(name == prefix or name.startswith(prefix + ".") for prefix in freeze_param):
                logging.info(f"Setting {name}.requires_grad = False")
                param.requires_grad = False

    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(
            model,
            device_ids=[local_rank],
            find_unused_parameters=train_conf.get("find_unused_parameters", False),
        )
    elif use_fsdp:
        wrap_policy = functools.partial(_custom_auto_wrap_policy, min_num_params=int(1e5))
        model = FSDP(
            model,
            auto_wrap_policy=wrap_policy,
            mixed_precision=None,
            device_id=torch.cuda.current_device(),
        )
    else:
        model = model.to(device=kwargs.get("device", "cuda"))

    logging.info(f"{model}")
    kwargs["device"] = next(model.parameters()).device

    # optim
    logging.info("Build optim")
    optim_name = kwargs.get("optim", "adam")
    if optim_name not in optim_classes:
        raise ValueError(f"Unsupported optim: {optim_name}")
    optim_class = optim_classes.get(optim_name)
    optim = optim_class(model.parameters(), **(kwargs.get("optim_conf") or {}))

    # scheduler
    logging.info("Build scheduler")
    scheduler_name = kwargs.get("scheduler", "warmuplr")
    if scheduler_name not in scheduler_classes:
        raise ValueError(f"Unsupported scheduler: {scheduler_name}")
    scheduler_class = scheduler_classes.get(scheduler_name)
    scheduler = scheduler_class(optim, **(kwargs.get("scheduler_conf") or {}))

    # dataset
    logging.info("Build dataloader")
    dataloader_name = kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle")
    dataloader_class = tables.dataloader_classes.get(dataloader_name)
    if dataloader_class is None:
        raise ValueError(f"Unsupported dataloader: {dataloader_name}")
    dataloader_tr, dataloader_val = dataloader_class(**kwargs)

    # One output directory for the trainer and tensorboard alike.
    output_dir = kwargs.get("output_dir") or "./exp"
    trainer = Trainer(
        local_rank=local_rank,
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
        device=kwargs["device"],
        output_dir=output_dir,
        **train_conf,
    )

    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
    scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler

    trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)

    tensorboard_dir = os.path.join(output_dir, "tensorboard")
    os.makedirs(tensorboard_dir, exist_ok=True)
    writer = None
    if trainer.rank == 0:
        try:
            from tensorboardX import SummaryWriter

            writer = SummaryWriter(tensorboard_dir)
        except Exception as err:  # tensorboard logging is optional
            logging.warning("tensorboard writer disabled: %s", err)

    # Join handles ranks that run out of batches early; only DDP implements the
    # Joinable interface (FSDP does not), so FSDP runs without it.
    context = Join([model]) if use_ddp else nullcontext()

    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
        epoch_start = time.perf_counter()
        with context:
            trainer.train_epoch(
                model=model,
                optim=optim,
                scheduler=scheduler,
                scaler=scaler,
                dataloader_train=dataloader_tr,
                dataloader_val=dataloader_val,
                epoch=epoch,
                writer=writer,
            )
        # NOTE(review): the scheduler is also passed to train_epoch; confirm the
        # Trainer does not step it per batch too, or this double-steps it.
        scheduler.step()
        trainer.validate_epoch(
            model=model,
            dataloader_val=dataloader_val,
            epoch=epoch,
            writer=writer,
        )

        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)

        elapsed_hours = (time.perf_counter() - epoch_start) / 3600.0
        logging.info(
            f"rank: {local_rank}, "
            f"time_escaped_epoch: {elapsed_hours:.3f} hours, "
            f"estimated to finish {trainer.max_epoch} "
            f"epoch: {(trainer.max_epoch - epoch) * elapsed_hours:.3f} hours\n"
        )

    if trainer.rank == 0:
        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)

    trainer.close()
| | | |
# Script entry point: main_hydra is decorated with @hydra.main, so it is
# called with no arguments and Hydra supplies the composed configuration.
if __name__ == "__main__":
    main_hydra()