From f68ae892be3e59c20a033f18d5f61db7f633801f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 17 May 2024 10:32:34 +0800
Subject: [PATCH] deepspeed: move optimizer/scheduler setup and step logging into the trainer
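
Move optimizer and scheduler construction out of funasr/bin/train_ds.py and
into Trainer.warp_optim_scheduler(), so DeepSpeed can take over when a
deepspeed_config is supplied: if the config file already defines an optimizer
or scheduler, the locally built ones are dropped and deepspeed.initialize()
constructs them instead, and update_step() calls the engine's step() rather
than optim.step()/scheduler.step(). The TensorBoard SummaryWriter moves into
the trainer, and per-step loss/timing statistics are collected into a single
loss_dict that is passed to Trainer.log().

For reference, a minimal deepspeed_config that exercises the new
optimizer/scheduler branch might look like the sketch below (illustrative
values only, not shipped with this patch):

    {
      "train_micro_batch_size_per_gpu": 1,
      "gradient_accumulation_steps": 1,
      "optimizer": { "type": "Adam", "params": { "lr": 5e-4 } },
      "scheduler": { "type": "WarmupLR", "params": { "warmup_num_steps": 1000 } },
      "zero_optimization": { "stage": 1 }
    }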
---
funasr/bin/train_ds.py | 38 -------
funasr/train_utils/trainer_ds.py | 229 +++++++++++++++++++++++----------------------
2 files changed, 119 insertions(+), 148 deletions(-)
diff --git a/funasr/bin/train_ds.py b/funasr/bin/train_ds.py
index e4db533..032a0cf 100644
--- a/funasr/bin/train_ds.py
+++ b/funasr/bin/train_ds.py
@@ -133,29 +133,7 @@
kwargs["device"] = next(model.parameters()).device
trainer.device = kwargs["device"]
- # optim
- logging.info("Build optim")
- optim = kwargs.get("optim", "adam")
- assert optim in optim_classes
- optim_class = optim_classes.get(optim)
- optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
-
- # scheduler
- logging.info("Build scheduler")
- scheduler = kwargs.get("scheduler", "warmuplr")
- assert scheduler in scheduler_classes
- scheduler_class = scheduler_classes.get(scheduler)
- scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
-
- if use_deepspeed:
- args = OmegaConf.create({"deepspeed_config": kwargs.get("deepspeed_config", "")})
- model, optimizer, _, scheduler = deepspeed.initialize(
- args=args,
- model=model,
- optimizer=optim,
- lr_scheduler=scheduler,
- model_parameters=model.parameters(),
- )
+ model, optim, scheduler = trainer.warp_optim_scheduler(model, **kwargs)
# dataset
logging.info("Build dataloader")
@@ -175,15 +153,6 @@
scaler=scaler,
)
- tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
- os.makedirs(tensorboard_dir, exist_ok=True)
- try:
- from tensorboardX import SummaryWriter
-
- writer = SummaryWriter(tensorboard_dir) # if trainer.rank == 0 else None
- except:
- writer = None
-
dataloader_tr, dataloader_val = None, None
for epoch in range(trainer.start_epoch, trainer.max_epoch):
time1 = time.perf_counter()
@@ -201,7 +170,6 @@
dataloader_train=dataloader_tr,
dataloader_val=dataloader_val,
epoch=epoch,
- writer=writer,
data_split_i=data_split_i,
data_split_num=dataloader.data_split_num,
start_step=trainer.start_step,
@@ -211,9 +179,7 @@
torch.cuda.empty_cache()
trainer.start_data_split_i = 0
- trainer.validate_epoch(
- model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
- )
+ trainer.validate_epoch(model=model, dataloader_val=dataloader_val, epoch=epoch + 1)
scheduler.step()
trainer.step_in_epoch = 0
trainer.save_checkpoint(
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 7188921..f3a0309 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -128,6 +128,14 @@
job_type="training",
reinit=True,
)
+ tensorboard_dir = os.path.join(output_dir, "tensorboard")
+ os.makedirs(tensorboard_dir, exist_ok=True)
+ try:
+ from tensorboardX import SummaryWriter
+
+ self.writer = SummaryWriter(tensorboard_dir)  # if self.rank == 0 else None
+ except ImportError:
+ self.writer = None
def save_checkpoint(
self,
@@ -331,7 +339,6 @@
dataloader_train=None,
dataloader_val=None,
epoch=None,
- writer=None,
**kwargs,
):
"""
@@ -356,14 +363,19 @@
time_beg = time.perf_counter()
time5 = time_beg
for batch_idx, batch in enumerate(dataloader_train):
- if self.use_ddp or self.use_fsdp:
- dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
- if iterator_stop > 0:
- break
+ loss_dict = {
+ "speed_stats": {},
+ "epoch": epoch,
+ "batch_idx": batch_idx,
+ "data_split_i": kwargs.get("data_split_i", 0),
+ "data_split_num": kwargs.get("data_split_num", 1),
+ "log_step": batch_idx + kwargs.get("start_step", 0),
+ }
+
self.batch_total += 1
self.step_in_epoch += 1
time1 = time.perf_counter()
- speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
+ loss_dict["speed_stats"]["data_load"] = f"{time1-time_beg:0.3f}"
batch = to_device(batch, self.device)
@@ -372,28 +384,28 @@
my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context
with my_context():
time2 = time.perf_counter()
- loss_dict = {}
+
self.forward_step(model, batch, loss_dict=loss_dict)
time3 = time.perf_counter()
- speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+ loss_dict["speed_stats"]["forward_time"] = f"{time3 - time2:0.3f}"
self.backward_step(model, scaler, loss_dict=loss_dict)
time4 = time.perf_counter()
- speed_stats["backward_and_AllReaduce_time"] = f"{time4 - time3:0.3f}"
-
- # self.train_loss_avg = (
- # self.train_loss_avg * (batch_idx + kwargs.get("start_step", 0))
- # + loss.detach().cpu().item()
- # ) / (batch_idx + kwargs.get("start_step", 0) + 1)
- # if "acc" in stats:
- # self.train_acc_avg = (
- # self.train_acc_avg * (batch_idx + kwargs.get("start_step", 0))
- # + stats["acc"].detach().cpu().item()
- # ) / (batch_idx + kwargs.get("start_step", 0) + 1)
+ loss_dict["speed_stats"]["backward_time"] = f"{time4 - time3:0.3f}"
self.update_step(model, optim, scheduler, scaler, loss_dict)
- # Perform an optimizer step only after accumulating enough gradients
+ total_time = f"{(time.perf_counter() - time5) / accum_grad:0.3f}"
+ time5 = time.perf_counter()
+
+ loss_dict["speed_stats"]["optim_time"] = f"{time5 - time4:0.3f}"
+
+ loss_dict["speed_stats"]["total_time"] = total_time
+
+ loss_dict["lr"] = scheduler.get_last_lr()[0]
+ loss_dict["batch_num_epoch"] = len(dataloader_train)
+
+ self.log(loss_dict, tag="train")
if self.step_in_epoch % self.validate_interval == 0:
self.validate_epoch(
@@ -434,28 +446,8 @@
with maybe_autocast(self.use_fp16):
retval = model(**batch)
- if (
- self.reset_gpu_cache
- and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
- ):
- torch.cuda.empty_cache()
-
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
- # if self.use_ddp or self.use_fsdp:
- # # Apply weighted averaging for loss and stats
- # loss = (loss * weight.type(loss.dtype)).sum()
- # # if distributed, this method can also apply all_reduce()
- # # stats, weight = recursive_average(stats, weight, distributed=True)
- # if self.use_ddp or self.use_fsdp:
- # dist.all_reduce(weight, op=dist.ReduceOp.SUM)
- # # Now weight is summation over all workers
- # loss /= weight.sum() # shape:[1] -> shape:[]
- # # Multiply world_size because DistributedDataParallel
- # # automatically normalizes the gradient by world_size.
- # loss *= self.world_size
- # loss *= self.world_size
- # Scale the loss since we're not updating for every mini-batch
loss_dict["loss"] = loss
loss_dict["stats"] = stats
@@ -474,68 +466,36 @@
loss.backward()
def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=None):
- if (batch_idx + 1) % self.accum_grad == 0:
- # Perform gradient clipping if it is set
- if self.grad_clip > 0:
- grad_norm = torch.nn.utils.clip_grad_norm_(
- model.parameters(),
- max_norm=self.grad_clip,
- norm_type=self.grad_clip_type,
- )
- if not torch.isfinite(grad_norm):
- logging.warning(f"The grad norm is {grad_norm}. Skipping updating the model.")
- optim.zero_grad() # Reset gradients
- return
- # Execute an optimization step (update model parameters)
- if self.use_ddp or self.use_fsdp:
- dist.barrier()
- if self.use_fp16:
- scaler.step(optim)
- scaler.update()
- else:
- optim.step()
- scheduler.step()
- # Clear gradients for the next accumulation stage
- optim.zero_grad(set_to_none=True)
+ if self.use_deepspeed:
+ model.step()
+ else:
+ if (batch_idx + 1) % self.accum_grad == 0:
+ # Perform gradient clipping if it is set
+ if self.grad_clip > 0:
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ model.parameters(),
+ max_norm=self.grad_clip,
+ norm_type=self.grad_clip_type,
+ )
+ if not torch.isfinite(grad_norm):
+ logging.warning(
+ f"The grad norm is {grad_norm}. Skipping updating the model."
+ )
+ optim.zero_grad() # Reset gradients
+ return
- if self.use_ddp or self.use_fsdp:
- train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
- self.device
- )
- train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
- self.device
- )
- dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
- dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
- self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
- self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
-
- total_time = f"{(time.perf_counter() - time5) / accum_grad:0.3f}"
- time5 = time.perf_counter()
-
- speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
-
- speed_stats["total_time"] = total_time
- lr = scheduler.get_last_lr()[0]
- batch_num_epoch = 1
- if hasattr(dataloader_train, "__len__"):
- batch_num_epoch = len(dataloader_train)
- self.log(
- epoch,
- batch_idx,
- log_step=batch_idx + kwargs.get("start_step", 0),
- step_in_epoch=self.step_in_epoch,
- batch_num_epoch=batch_num_epoch,
- lr=lr,
- loss=loss.detach().cpu().item(),
- speed_stats=speed_stats,
- stats=stats,
- writer=writer,
- tag="train",
- data_split_i=kwargs.get("data_split_i", 0),
- data_split_num=kwargs.get("data_split_num", 1),
- )
+ # Execute an optimization step (update model parameters)
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+ if self.use_fp16:
+ scaler.step(optim)
+ scaler.update()
+ else:
+ optim.step()
+ scheduler.step()
+ # Clear gradients for the next accumulation stage
+ optim.zero_grad(set_to_none=True)
def validate_epoch(
self,
@@ -646,21 +606,22 @@
def log(
self,
- epoch=0,
- batch_idx=0,
- step_in_epoch=0,
- batch_num_epoch=-1,
- lr=0.0,
- loss=0.0,
- speed_stats=None,
- stats=None,
- writer=None,
+ loss_dict: dict = None,
tag="train",
- data_split_i=0,
- data_split_num=1,
- log_step=None,
**kwargs,
):
+ loss = loss_dict["loss"].detach().cpu().item()
+ epoch = loss_dict["epoch"]
+ batch_idx = loss_dict["batch_idx"]
+ step_in_epoch = loss_dict.get("step_in_epoch", self.step_in_epoch)
+ batch_num_epoch = loss_dict["batch_num_epoch"]
+ lr = loss_dict["lr"]
+
+ speed_stats = loss_dict["speed_stats"]
+ stats = loss_dict["stats"]
+ data_split_i = loss_dict["data_split_i"]
+ data_split_num = loss_dict["data_split_num"]
+ log_step = loss_dict.get("log_step", None)
if (batch_idx + 1) % self.log_interval == 0:
batch_idx = log_step if log_step is not None else batch_idx
@@ -700,6 +661,7 @@
f"rank{self.rank}_lr/{tag}": lr,
}
+ writer = self.writer
if writer is not None:
writer.add_scalar(f"rank{self.rank}_loss/{tag}", loss, self.batch_total)
writer.add_scalar(f"rank{self.rank}_lr/{tag}", lr, self.batch_total)
@@ -798,3 +760,46 @@
model = model.to(device=kwargs.get("device", "cuda"))
return model
+
+ def warp_optim_scheduler(self, model, **kwargs):
+
+ # optim
+ logging.info("Build optim")
+ optim = kwargs.get("optim", "adam")
+ assert optim in optim_classes
+ optim_class = optim_classes.get(optim)
+ optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
+
+ # scheduler
+ logging.info("Build scheduler")
+ scheduler = kwargs.get("scheduler", "warmuplr")
+ assert scheduler in scheduler_classes
+ scheduler_class = scheduler_classes.get(scheduler)
+ scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
+
+ if self.use_deepspeed:
+ deepspeed_config = kwargs.get("deepspeed_config", "")
+ with open(deepspeed_config, "r") as fin:
+ ds_configs = json.load(fin)
+ if "optimizer" in ds_configs:
+ # NOTE(xcsong): Disable the custom optimizer if one is set in ds_config;
+ # extremely useful when cpu_offload is enabled, since DeepSpeedCPUAdam
+ # can be 4~5x faster than torch's native Adam
+ optim = None
+ if "scheduler" in ds_configs:
+ scheduler = None
+ else:
+
+ def scheduler(opt):
+ return scheduler_class(opt, **kwargs.get("scheduler_conf"))
+
+ args = OmegaConf.create({"deepspeed_config": deepspeed_config})
+ model, optim, _, scheduler = deepspeed.initialize(
+ args=args,
+ model=model,
+ optimizer=optim,
+ lr_scheduler=scheduler,
+ model_parameters=model.parameters(),
+ )
+
+ return model, optim, scheduler
--
Gitblit v1.9.1