From fc68b5ffe453235294a561737d8e84bb6c1689a4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 25 Apr 2024 21:43:47 +0800
Subject: [PATCH] Dev gzf exp (#1661)
---
funasr/bin/train.py | 24 +++++++++++++++---------
1 file changed, 15 insertions(+), 9 deletions(-)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index eb1611a..448e464 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -99,7 +99,7 @@
if freeze_param is not None:
if "," in freeze_param:
freeze_param = eval(freeze_param)
- if isinstance(freeze_param, Sequence):
+ if not isinstance(freeze_param, Sequence):
freeze_param = (freeze_param,)
logging.info("freeze_param is not None: %s", freeze_param)
for t in freeze_param:
@@ -107,8 +107,9 @@
if k.startswith(t + ".") or k == t:
logging.info(f"Setting {k}.requires_grad = False")
p.requires_grad = False
-    logging.info(f"model info: {model_summary(model)}")
+    if local_rank == 0:
+        logging.info(f"{model_summary(model)}")
if use_ddp:
model = model.cuda(local_rank)
model = DDP(
@@ -145,8 +146,6 @@
else:
model = model.to(device=kwargs.get("device", "cuda"))
- if local_rank == 0:
- logging.info(f"{model}")
kwargs["device"] = next(model.parameters()).device
# optim
@@ -182,7 +181,12 @@
scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
- trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+ trainer.resume_checkpoint(
+ model=model,
+ optim=optim,
+ scheduler=scheduler,
+ scaler=scaler,
+ )
tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
os.makedirs(tensorboard_dir, exist_ok=True)
@@ -197,8 +201,11 @@
for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
time1 = time.perf_counter()
- for data_split_i in range(dataloader.data_split_num):
- dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
+ for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
+ dataloader_tr, dataloader_val = dataloader.build_iter(
+ epoch, data_split_i=data_split_i, start_step=trainer.start_step
+ )
+ trainer.start_step = 0
trainer.train_epoch(
model=model,
optim=optim,
@@ -211,9 +218,8 @@
data_split_i=data_split_i,
data_split_num=dataloader.data_split_num,
)
-
- torch.cuda.empty_cache()
+ torch.cuda.empty_cache()
trainer.validate_epoch(
model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
--
Gitblit v1.9.1
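
Editor's note: the substantive change is in the last two training-loop hunks. On resume, the epoch loop re-enters at the checkpointed data split (trainer.start_data_split_i), passes the checkpointed step offset (trainer.start_step) into build_iter, and then zeroes the offset so every later split starts from step 0. Below is a minimal, self-contained sketch of that control flow. Trainer and Loader are hypothetical stand-ins, not funasr's real classes, and the reset of start_data_split_i after the first epoch is an assumption: the hunk above only resets start_step.

# Sketch of the resume-at-data-split flow introduced by the patch above.
# Trainer/Loader are hypothetical stand-ins for funasr's objects.

class Trainer:
    def __init__(self, start_epoch=0, max_epoch=1, start_data_split_i=0, start_step=0):
        self.start_epoch = start_epoch                 # restored by resume_checkpoint
        self.max_epoch = max_epoch
        self.start_data_split_i = start_data_split_i  # split to re-enter on resume
        self.start_step = start_step                   # batches to skip in that split

class Loader:
    data_split_num = 3

    def build_iter(self, epoch, data_split_i=0, start_step=0):
        # A real loader would skip `start_step` batches of split `data_split_i`;
        # this stub just reports what would be built.
        print(f"epoch={epoch} split={data_split_i} skip={start_step}")
        return [], []

# Pretend we resumed from a checkpoint taken at epoch 0, split 1, step 200.
trainer = Trainer(start_epoch=0, max_epoch=1, start_data_split_i=1, start_step=200)
loader = Loader()

for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
    for data_split_i in range(trainer.start_data_split_i, loader.data_split_num):
        dl_tr, dl_val = loader.build_iter(
            epoch, data_split_i=data_split_i, start_step=trainer.start_step
        )
        trainer.start_step = 0  # only the first resumed split skips batches
    # Assumption: later epochs iterate all splits from the beginning; the patch
    # itself only resets start_step, so this reset presumably happens elsewhere.
    trainer.start_data_split_i = 0

Under these assumptions, a run killed mid-epoch re-enters exactly where the checkpoint left off instead of replaying earlier data splits.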