From 3ac03e448b7673604eb86f619b27521fca55f34d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 21 Mar 2024 01:36:39 +0800
Subject: [PATCH] train & finetune llm-asr (#1519)
---
funasr/train_utils/trainer_llm.py | 178 ++++++++++++++++++++++++++++++++++++-----------------------
1 file changed, 109 insertions(+), 69 deletions(-)
diff --git a/funasr/train_utils/trainer_llm.py b/funasr/train_utils/trainer_llm.py
index 6a3b83b..5f13b5a 100644
--- a/funasr/train_utils/trainer_llm.py
+++ b/funasr/train_utils/trainer_llm.py
@@ -1,3 +1,4 @@
+import math
import os
import time
import torch
@@ -61,6 +62,8 @@
"""
self.output_dir = output_dir
+ if not os.path.exists(self.output_dir):
+ os.makedirs(self.output_dir, exist_ok=True)
self.resume = kwargs.get('resume', True)
self.start_epoch = 0
self.max_epoch = kwargs.get('max_epoch', 100)
@@ -78,6 +81,7 @@
# scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
# self.scaler = scaler
self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
+ self.keep_nbest_models = kwargs.get("keep_nbest_models", -1)
self.accum_grad = kwargs.get("accum_grad", 1)
self.grad_clip = kwargs.get("grad_clip", 10.0)
self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
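
The new `keep_nbest_models` knob bounds disk usage: when positive, only that many checkpoints are retained, ranked by validation accuracy, and the default of -1 disables pruning (the eviction logic itself lands in `save_checkpoint` below). A minimal sketch of passing these options, assuming a Trainer-style constructor that reads them via `kwargs.get` as above; names mirror the kwargs, values are illustrative:

    # Hypothetical construction; argument names mirror the kwargs read in __init__.
    trainer = Trainer(
        output_dir="exp/llm_asr",
        resume=True,
        max_epoch=20,
        save_checkpoint_interval=5000,  # also save mid-epoch every 5000 steps
        keep_nbest_models=5,            # keep only the 5 best-accuracy checkpoints
        accum_grad=4,
        grad_clip=10.0,
    )
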
@@ -93,6 +97,15 @@
logging.warning("distributed is not initialized, only single shard")
self.rank = rank
self.world_size = world_size
+ self.train_acc_avg = 0.0
+ self.train_loss_avg = 0.0
+ self.val_acc_avg = 0.0
+ self.val_loss_avg = 0.0
+ self.best_acc_idx = 0
+ self.saved_ckpts = {}
+ self.val_acc_list = []
+ self.step_or_epoch = -1
+
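
These accumulators hold epoch-level running means, updated one batch at a time with the recurrence avg_{n+1} = (n * avg_n + x_{n+1}) / (n + 1); the update sites appear in the train and validate hunks below. A standalone illustration of the recurrence:

    # Incremental mean: after n+1 values, avg equals sum(values[:n+1]) / (n+1),
    # with no per-batch history kept.
    def running_mean(values):
        avg = 0.0
        for n, x in enumerate(values):
            avg = (avg * n + x) / (n + 1)
        return avg

    assert running_mean([1.0, 2.0, 3.0, 4.0]) == 2.5
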
@@ -112,28 +125,56 @@
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
+
if self.rank == 0:
+ logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
+ self.step_or_epoch += 1
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optim.state_dict(),
'scheduler': scheduler.state_dict(),
+ "acc": self.val_acc_list,
+ "step_or_epoch": self.step_or_epoch,
}
+ if hasattr(model, "module"):
+ state["state_dict"] = model.module.state_dict()
+
if scaler:
state["scaler_state"] = scaler.state_dict()
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
if step is None:
- filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
+ ckpt_name = f'model.pt.ep{epoch}'
else:
- filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')
-
+ ckpt_name = f'model.pt.ep{epoch}.{step}'
+ filename = os.path.join(self.output_dir, ckpt_name)
torch.save(state, filename)
- print(f'\nCheckpoint saved to {filename}\n')
+ logging.info(f'\nCheckpoint saved to {filename}\n')
latest = Path(os.path.join(self.output_dir, f'model.pt'))
torch.save(state, latest)
-
+
+ if self.val_acc_list[self.step_or_epoch] >= self.val_acc_list[self.best_acc_idx]:
+ self.best_acc_idx = self.step_or_epoch
+ best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
+ torch.save(state, best_ckpt)
+ logging.info(f"Update best acc: {self.val_acc_list[self.best_acc_idx]}, {best_ckpt}")
+ else:
+ logging.info(f"No improvement in acc: {self.val_acc_list[self.best_acc_idx]}")
+
+ if self.keep_nbest_models > 0:
+ self.saved_ckpts[ckpt_name] = self.val_acc_list[-1]
+ if len(self.saved_ckpts) > self.keep_nbest_models:
+ # evict the checkpoint with the lowest validation accuracy
+ min_key = min(self.saved_ckpts, key=self.saved_ckpts.get)
+ del self.saved_ckpts[min_key]
+ filename = os.path.join(self.output_dir, min_key)
+ logging.info(f"Delete: {filename}")
+ if os.path.exists(filename):
+ os.remove(filename)
+
if self.use_ddp or self.use_fsdp:
dist.barrier()
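
Two retention policies are combined above: `model.pt.best` always mirrors the highest-accuracy checkpoint, while the `keep_nbest_models` branch deletes the lowest-accuracy checkpoint file once the count exceeds the budget. A self-contained sketch of that eviction policy, matching the dict-based bookkeeping above (the helper name is illustrative):

    import os

    def prune_nbest(saved_ckpts, ckpt_name, val_acc, keep_nbest, output_dir):
        """Record a new checkpoint; if over budget, delete the worst one."""
        saved_ckpts[ckpt_name] = val_acc
        if len(saved_ckpts) <= keep_nbest:
            return
        worst = min(saved_ckpts, key=saved_ckpts.get)  # lowest validation acc
        del saved_ckpts[worst]
        path = os.path.join(output_dir, worst)
        if os.path.exists(path):
            os.remove(path)
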
@@ -173,6 +214,10 @@
scheduler.load_state_dict(checkpoint['scheduler'])
if scaler is not None and 'scaler_state' in checkpoint:
scaler.load_state_dict(checkpoint['scaler_state'])
+
+ self.val_acc_list = checkpoint["acc"]
+ self.step_or_epoch = checkpoint["step_or_epoch"]
+
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
print(f"No checkpoint found at '{ckpt}', does not resume status!")
@@ -180,52 +225,7 @@
if self.use_ddp or self.use_fsdp:
dist.barrier()
- # def train(self):
- # """
- # Starts the training process, iterating over epochs, training the model,
- # and saving checkpoints at the end of each epoch.
- # """
- # if self.resume:
- # self.resume_checkpoint(self.output_dir)
- #
- # for epoch in range(self.start_epoch, self.max_epoch + 1):
- # time1 = time.perf_counter()
- # self.train_epoch(epoch)
- #
- #
- #
- # if self.use_ddp or self.use_fsdp:
- # dist.barrier()
- #
- # self._validate_epoch(epoch)
- #
- # if self.use_ddp or self.use_fsdp:
- # dist.barrier()
- #
- #
- # if self.rank == 0:
- # self._save_checkpoint(epoch)
- #
- # if self.use_ddp or self.use_fsdp:
- # dist.barrier()
- #
- # self.scheduler.step()
- #
- # time2 = time.perf_counter()
- # time_escaped = (time2 - time1)/3600.0
- # print(f"\nrank: {self.local_rank}, time_escaped_epoch: {time_escaped:.3f} hours, estimated to finish {self.max_epoch} epoch: {(self.max_epoch-epoch)*time_escaped:.3f} hours\n")
- #
- # if self.rank == 0:
- # average_checkpoints(self.output_dir, self.avg_nbest_model)
- #
- # if self.use_ddp or self.use_fsdp:
- # dist.barrier()
- #
- #
- # if writer:
- # writer.close()
- #
-
+
def train_epoch(self,
model=None,
optim=None,
@@ -241,9 +241,9 @@
Args:
epoch (int): The current epoch number.
"""
+ logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
model.train()
-
# Set the number of steps for gradient accumulation
accum_grad = self.accum_grad
# Initialize the gradient accumulation
@@ -288,6 +288,18 @@
loss.backward()
time4 = time.perf_counter()
speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
+
+ self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
+ if "acc" in stats:
+ self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
+ if self.use_ddp or self.use_fsdp:
+ train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
+ train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
+ dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+ dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+ self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+ self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+
# Perform an optimizer step only after accumulating enough gradients
if (batch_idx + 1) % accum_grad == 0:
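
Under DDP or FSDP each rank only sees its own shard of the data, so the epoch averages are synchronized by summing across ranks with `all_reduce` and dividing by the world size, exactly as above. The pattern in isolation (it assumes `torch.distributed` has already been initialized, e.g. with `init_process_group`):

    import torch
    import torch.distributed as dist

    def global_mean(local_value, device):
        """Average a per-rank scalar across all ranks: sum, then divide."""
        t = torch.tensor(local_value, dtype=torch.float32, device=device)
        dist.all_reduce(t, op=dist.ReduceOp.SUM)  # in-place sum over ranks
        return t.item() / dist.get_world_size()
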
@@ -322,9 +334,11 @@
speed_stats["total_time"] = total_time
lr = scheduler.get_last_lr()[0]
-
+ batch_num_epoch = -1
+ if hasattr(dataloader_train, "__len__"):
+ batch_num_epoch = len(dataloader_train)
self.log(epoch, batch_idx,
- batch_num_epoch=len(dataloader_train),
+ batch_num_epoch=batch_num_epoch,
lr=lr,
loss=loss.detach().cpu().item(),
speed_stats=speed_stats,
@@ -341,7 +355,7 @@
writer=writer
)
- if (batch_idx+1) % self.save_checkpoint_interval == 0 and self.rank == 0:
+ if (batch_idx+1) % self.save_checkpoint_interval == 0:
self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
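
Dropping the `rank == 0` guard here is deliberate, not a regression: `save_checkpoint` ends with a `dist.barrier()`, a collective that every rank must enter, so calling it from rank 0 alone would leave the other ranks hanging (the file write is still rank-0-only inside the method). The pattern, sketched:

    import torch
    import torch.distributed as dist

    # All ranks must call in; only rank 0 touches the filesystem.
    def save_on_all_ranks(rank, state, path):
        if rank == 0:
            torch.save(state, path)
        dist.barrier()  # collective: deadlocks if any rank never reaches it
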
@@ -364,6 +378,7 @@
Args:
epoch (int): The current epoch number.
"""
+ logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
model.eval()
with torch.no_grad():
@@ -394,18 +409,35 @@
loss = loss
time4 = time.perf_counter()
+ self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
+ if "acc" in stats:
+ self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
+ if self.use_ddp or self.use_fsdp:
+ val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
+ val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
+ dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+ self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+ self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+ batch_num_epoch = -1
+ if hasattr(dataloader_val, "__len__"):
+ batch_num_epoch = len(dataloader_val)
self.log(epoch, batch_idx,
- batch_num_epoch=len(dataloader_val),
+ batch_num_epoch=batch_num_epoch,
lr=0.0,
loss=loss.detach().cpu().item(),
speed_stats=speed_stats,
stats=stats,
writer=writer,
- tag="train",
+ tag="val",
)
+ self.val_acc_list.append(self.val_acc_avg)
model.train()
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
def log(self,
@@ -422,39 +454,47 @@
if (batch_idx + 1) % self.log_interval == 0:
- gpu_info = "GPU, memory: {:.3f} GB, " \
- "{:.3f} GB, " \
- "{:.3f} GB, " \
- "{:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+ gpu_info = "GPU, memory: usage: {:.3f} GB, " \
+ "peak: {:.3f} GB, " \
+ "cache: {:.3f} GB, " \
+ "cache_peak: {:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
)
- time_now = datetime.now()
- time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
+ loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
+ acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
description = (
- f"{time_now}, "
+ f"{tag}, "
f"rank: {self.local_rank}, "
f"epoch: {epoch}/{self.max_epoch}, "
f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
- f"(loss: {loss:.3f}), "
+ f"(loss_avg_rank: {loss:.3f}), "
+ f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
+ f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}), "
+ f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
f"(lr: {lr:.3e}), "
- f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
+ f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
f"{speed_stats}, "
f"{gpu_info}"
)
logging.info(description)
if writer is not None:
- writer.add_scalar(f'rank{self.local_rank}_Loss/{tag}', loss, self.batch_total)
+ writer.add_scalar(f'rank{self.local_rank}_loss/{tag}', loss, self.batch_total)
writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
for key, var in stats.items():
- writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
+ writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
for key, var in speed_stats.items():
- writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total)
+ writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', float(var), self.batch_total)
def close(self, writer=None):
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+
if writer is not None:
writer.close()
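
The `ppl_avg_epoch` field added to the log line treats the running loss as a mean token-level cross-entropy in nats and reports perplexity = exp(loss). A one-line worked example (the loss value is illustrative):

    import math

    loss_avg_epoch = 2.0  # mean negative log-likelihood per token, in nats
    print(f"ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}")  # 7.389, i.e. as
    # uncertain as a uniform choice over roughly 7.4 tokens
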
--
Gitblit v1.9.1