游雁
2025-02-13 604ae30fdbe96185282e6c83134e11217f3acd20
funasr/bin/train_ds.py
@@ -27,7 +27,7 @@
 from funasr.train_utils.trainer_ds import Trainer
 from funasr.schedulers import scheduler_classes
 from funasr.train_utils.initialize import initialize
-from funasr.download.download_from_hub import download_model
+from funasr.download.download_model_from_hub import download_model
 from funasr.models.lora.utils import mark_only_lora_as_trainable
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
@@ -81,8 +81,13 @@
         deepspeed.init_distributed(dist_backend=kwargs.get("backend", "nccl"))
     elif use_ddp or use_fsdp:
         logging.info(f"use_ddp: {use_ddp}, use_fsdp: {use_fsdp}")
-        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
+        dist.init_process_group(
+            backend=kwargs.get("backend", "nccl"),
+            init_method="env://",
+        )
         torch.cuda.set_device(local_rank)
     # rank = dist.get_rank()
     logging.info("Build model, frontend, tokenizer")
     device = kwargs.get("device", "cuda")
@@ -124,11 +129,12 @@
         use_ddp=use_ddp,
         use_fsdp=use_fsdp,
         device=kwargs["device"],
+        excludes=kwargs.get("excludes", None),
         output_dir=kwargs.get("output_dir", "./exp"),
         **kwargs.get("train_conf"),
     )
-    model = trainer.warp_model(model)
+    model = trainer.warp_model(model, **kwargs)
     kwargs["device"] = int(os.environ.get("LOCAL_RANK", 0))
     trainer.device = int(os.environ.get("LOCAL_RANK", 0))
@@ -143,7 +149,7 @@
     dataloader = dataloader_class(**kwargs)
     # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
-    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
+    scaler = GradScaler(enabled=True) if trainer.use_fp16 or trainer.use_bf16 else None
     scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
     trainer.resume_checkpoint(
@@ -178,14 +184,17 @@
             )
             trainer.start_step = 0
-            torch.cuda.empty_cache()
+            device = next(model.parameters()).device
+            if device.type == "cuda":
+                with torch.cuda.device(device):
+                    torch.cuda.empty_cache()
             time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
             logging.info(
-                f"rank: {local_rank}, "
+                f"\n\nrank: {local_rank}, "
                 f"time_escaped_epoch: {time_escaped:.3f} hours, "
-                f"estimated to finish {dataloader.data_split_num} data_slices, remaining: {(dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours"
-                f"epoch: {((trainer.max_epoch - epoch - 1)*dataloader.data_split_num + dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours\n"
+                f"estimated to finish {dataloader.data_split_num} data_slices, remaining: {dataloader.data_split_num-data_split_i} slices, {(dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours, "
+                f"epoch: {trainer.max_epoch - epoch} epochs, {((trainer.max_epoch - epoch - 1)*dataloader.data_split_num + dataloader.data_split_num-data_split_i)*time_escaped:.3f} hours\n"
             )
         trainer.start_data_split_i = 0
@@ -199,7 +208,7 @@
         time2 = time.perf_counter()
         time_escaped = (time2 - time1) / 3600.0
         logging.info(
-            f"rank: {local_rank}, "
+            f"\n\nrank: {local_rank}, "
             f"time_escaped_epoch: {time_escaped:.3f} hours, "
             f"estimated to finish {trainer.max_epoch} "
             f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n"