From 16a976a01d110d3969759be7720cae2b6b0664f7 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 24 三月 2024 01:27:08 +0800
Subject: [PATCH] finetune
---
funasr/train_utils/trainer.py | 80 ++++++++++++++++++++--------------------
funasr/bin/train.py | 3 +
2 files changed, 42 insertions(+), 41 deletions(-)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 0ff4ba1..5cf54da 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -128,7 +128,8 @@
else:
model = model.to(device=kwargs.get("device", "cuda"))
- logging.info(f"{model}")
+ if local_rank == 0:
+ logging.info(f"{model}")
kwargs["device"] = next(model.parameters()).device
# optim
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 23c18d9..cf23483 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -239,6 +239,8 @@
Args:
epoch (int): The current epoch number.
"""
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
model.train()
@@ -248,15 +250,14 @@
optim.zero_grad()
speed_stats = {}
time5 = time.perf_counter()
- iterator_stop = torch.tensor(0).to(self.device)
- dist.barrier()
- print(f"before iter, iterator_stop: {iterator_stop}\n")
+ # iterator_stop = torch.tensor(0).to(self.device)
+
dataloader_train.batch_sampler.set_epoch(epoch)
for batch_idx, batch in enumerate(dataloader_train):
- if self.use_ddp or self.use_fsdp:
- dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
- if iterator_stop > 0:
- break
+ # if self.use_ddp or self.use_fsdp:
+ # dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+ # if iterator_stop > 0:
+ # break
self.batch_total += 1
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1-time5:0.3f}"
@@ -297,13 +298,13 @@
self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
if "acc" in stats:
self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
- # if self.use_ddp or self.use_fsdp:
- # train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
- # train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
- # dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
- # dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
- # self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
- # self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+ if self.use_ddp or self.use_fsdp:
+ train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
+ train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
+ dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+ dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+ self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+ self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
# Perform an optimizer step only after accumulating enough gradients
@@ -363,10 +364,10 @@
if (batch_idx+1) % self.save_checkpoint_interval == 0:
self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
- else:
- if self.use_ddp or self.use_fsdp:
- iterator_stop.fill_(1)
- dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+ # else:
+ # if self.use_ddp or self.use_fsdp:
+ # iterator_stop.fill_(1)
+ # dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
if self.use_ddp or self.use_fsdp:
dist.barrier()
@@ -387,6 +388,8 @@
Args:
epoch (int): The current epoch number.
"""
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
model.eval()
@@ -394,16 +397,15 @@
speed_stats = {}
time5 = time.perf_counter()
- iterator_stop = torch.tensor(0).to(self.device)
- dist.barrier()
- print(f"before iter, iterator_stop: {iterator_stop}\n")
+ # iterator_stop = torch.tensor(0).to(self.device)
+
for batch_idx, batch in enumerate(dataloader_val):
- if self.use_ddp or self.use_fsdp:
- dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
- if epoch >= 1:
- print(f"iterator_stop: {iterator_stop}\n")
- if iterator_stop > 0:
- break
+ # if self.use_ddp or self.use_fsdp:
+ # dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+ # if epoch >= 1:
+ # print(f"iterator_stop: {iterator_stop}\n")
+ # if iterator_stop > 0:
+ # break
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1 - time5:0.3f}"
batch = to_device(batch, self.device)
@@ -432,13 +434,13 @@
self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
if "acc" in stats:
self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
- # if self.use_ddp or self.use_fsdp:
- # val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
- # val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
- # dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
- # dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
- # self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
- # self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+ if self.use_ddp or self.use_fsdp:
+ val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
+ val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
+ dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+ self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+ self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
batch_num_epoch = -1
if hasattr(dataloader_val, "__len__"):
@@ -453,15 +455,13 @@
tag="val",
)
- else:
- if self.use_ddp or self.use_fsdp:
- iterator_stop.fill_(1)
- dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+ # else:
+ # if self.use_ddp or self.use_fsdp:
+ # iterator_stop.fill_(1)
+ # dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
self.val_acc_list.append(self.val_acc_avg)
model.train()
-
-
if self.use_ddp or self.use_fsdp:
dist.barrier()
--
Gitblit v1.9.1