From b85e140c3e4a7a7ccba59abfc67b63aac7a28dd9 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 17 May 2024 17:28:56 +0800
Subject: [PATCH] trainer_ds: fix DeepSpeed setup, use bf16 autocast, clean up logging
---
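Note (not part of the commit): the forward_step() hunk below drops maybe_autocast(self.use_fp16)
and always runs the forward pass under a bfloat16 autocast context. A minimal sketch of that
behaviour, assuming a CUDA device with bf16 support; the toy model and batch are illustrative
stand-ins, not code from this patch:

    import torch

    model = torch.nn.Linear(16, 4).cuda()                 # stand-in model, not FunASR's
    batch = {"input": torch.randn(2, 16, device="cuda")}  # stand-in batch

    def forward_step(model, batch):
        # Same autocast call as in the patched trainer: bf16, no autocast weight cache.
        with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=False):
            out = model(batch["input"])    # matmul runs in bf16 under autocast
            loss = out.float().mean()      # reduce in fp32 for a stable scalar loss
        return loss

    loss = forward_step(model, batch)
    loss.backward()                        # bf16 autocast needs no GradScaler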
funasr/train_utils/trainer_ds.py | 60 +++++++++++++++++++++++-------------------------------------
 1 file changed, 23 insertions(+), 37 deletions(-)
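Note (not part of the commit): the last hunk moves the DeepSpeed setup into
warp_optim_scheduler(), reading self.deepspeed_config and dropping the locally built
optimizer/scheduler when the ds_config already defines its own. A self-contained sketch of
that flow; the config contents, toy model, and variable names are illustrative assumptions:

    import json
    import torch
    import deepspeed
    from omegaconf import OmegaConf

    model = torch.nn.Linear(8, 8)                           # stand-in model
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)   # locally built optimizer
    scheduler = None

    ds_config_path = "ds_config.json"                       # trainer gets this via kwargs["deepspeed_config"]
    with open(ds_config_path, "w") as fout:                 # illustrative config only
        json.dump({"train_micro_batch_size_per_gpu": 1,
                   "optimizer": {"type": "Adam", "params": {"lr": 1e-3}}}, fout)

    with open(ds_config_path, "r") as fin:
        ds_configs = json.load(fin)

    # As in the patch: let DeepSpeed build the optimizer/scheduler defined in ds_config
    # (e.g. DeepSpeedCPUAdam with cpu_offload) instead of the locally built ones.
    if "optimizer" in ds_configs:
        optim = None
    if "scheduler" in ds_configs:
        scheduler = None

    args = OmegaConf.create({"deepspeed_config": ds_config_path})
    model_engine, optimizer, _, scheduler = deepspeed.initialize(
        args=args,
        model=model,
        optimizer=optim,
        model_parameters=model.parameters() if optim is None else None,
        lr_scheduler=scheduler,
    )

Run the sketch under a deepspeed/torchrun launcher so distributed initialization succeeds.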
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index f3a0309..8afbc6d 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -78,7 +78,7 @@
self.world_size = world_size
self.use_ddp = use_ddp
self.use_fsdp = use_fsdp
- self.use_deepspeed = use_deepspeed
+
self.device = kwargs.get("device", "cuda")
self.output_dir = output_dir
@@ -136,6 +136,9 @@
self.writer = SummaryWriter(tensorboard_dir) # if trainer.rank == 0 else None
except:
self.writer = None
+
+ self.use_deepspeed = use_deepspeed
+ self.deepspeed_config = kwargs.get("deepspeed_config", "")
def save_checkpoint(
self,
@@ -394,8 +397,8 @@
time4 = time.perf_counter()
loss_dict["speed_stats"]["backward_time"] = f"{time4 - time3:0.3f}"
- self.update_step(model, optim, scheduler, scaler, loss_dict)
- total_time = f"{(time.perf_counter() - time5) / accum_grad:0.3f}"
+ self.update_step(model, optim, scheduler, scaler, loss_dict=loss_dict)
+ total_time = f"{(time.perf_counter() - time5):0.3f}"
time5 = time.perf_counter()
loss_dict["speed_stats"]["optim_time"] = f"{time5 - time4:0.3f}"
@@ -412,7 +415,7 @@
model=model,
dataloader_val=dataloader_val,
epoch=epoch,
- writer=writer,
+ writer=self.writer,
step=batch_idx + 1,
step_in_epoch=self.step_in_epoch,
)
@@ -443,7 +446,8 @@
iterator_stop = torch.tensor(0).to(self.device)
def forward_step(self, model, batch, loss_dict={}):
- with maybe_autocast(self.use_fp16):
+ dtype = torch.bfloat16
+ with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
retval = model(**batch)
loss, stats, weight = retval
@@ -465,8 +469,8 @@
else:
loss.backward()
- def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=loss_dict):
-
+ def update_step(self, model, optim, scheduler, scaler, loss_dict=None):
+ batch_idx = loss_dict["batch_idx"]
if self.use_deepspeed:
model.step()
else:
@@ -613,7 +617,7 @@
loss = loss_dict["loss"].detach().cpu().item()
epoch = loss_dict["epoch"]
batch_idx = loss_dict["batch_idx"]
- step_in_epoch = loss_dict["step_in_epoch"]
+ step_in_epoch = self.step_in_epoch
batch_num_epoch = loss_dict["batch_num_epoch"]
lr = loss_dict["lr"]
@@ -732,36 +736,17 @@
"find_unused_parameters", False
),
)
- # elif self.use_fsdp:
- # # model = FSDP(model).cuda(local_rank)
- #
- # def custom_auto_wrap_policy(
- # module: nn.Module,
- # recurse: bool,
- # nonwrapped_numel: int,
- # # Additional custom arguments
- # min_num_params: int = int(1e8),
- # ) -> bool:
- # Decide whether to wrap the module based on custom logic
- # is_large = unwrapped_params >= min_num_params
- # requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
- # return is_large and requires_grad_uniform
- #
- # # Configure a custom `min_num_params`
- # my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
- # torch.cuda.set_device(local_rank)
- # model = FSDP(
- # model,
- # auto_wrap_policy=custom_auto_wrap_policy,
- # mixed_precision=None,
- # device_id=torch.cuda.current_device(),
- # )
+
else:
model = model.to(device=kwargs.get("device", "cuda"))
return model
def warp_optim_scheduler(self, model, **kwargs):
+ from funasr.optimizers import optim_classes
+ from funasr.schedulers import scheduler_classes
+ from omegaconf import OmegaConf, DictConfig
+ import json
# optim
logging.info("Build optim")
@@ -777,15 +762,17 @@
scheduler_class = scheduler_classes.get(scheduler)
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
- if use_deepspeed:
- deepspeed_config = kwargs.get("deepspeed_config", "")
- with open(deepspeed_config, "r") as fin:
+ if self.use_deepspeed:
+ import deepspeed
+
+ args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
+ with open(self.deepspeed_config, "r") as fin:
ds_configs = json.load(fin)
if "optimizer" in ds_configs:
# NOTE(xcsong): Disable custom optimizer if it is set in ds_config,
# extremely useful when enable cpu_offload, DeepspeedCpuAdam
# could be 4~5x faster than torch native adam
- deepspeed_config = None
+ optim = None
if "scheduler" in ds_configs:
scheduler = None
else:
@@ -793,7 +780,6 @@
def scheduler(opt):
return scheduler_class(opt, **kwargs.get("scheduler_conf"))
- args = OmegaConf.create({"deepspeed_config": deepspeed_config})
model, optimizer, _, scheduler = deepspeed.initialize(
args=args,
model=model,
--
Gitblit v1.9.1