zhifu gao
2024-03-21 3ac03e448b7673604eb86f619b27521fca55f34d
funasr/train_utils/trainer_llm.py
@@ -1,3 +1,4 @@
import math
import os
import time
import torch
@@ -61,6 +62,8 @@
        """
        
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        self.resume = kwargs.get('resume', True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get('max_epoch', 100)
@@ -78,6 +81,7 @@
        # scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
        # self.scaler = scaler
        self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
        self.keep_nbest_models = kwargs.get("keep_nbest_models", -1)
        self.accum_grad = kwargs.get("accum_grad", 1)
        self.grad_clip = kwargs.get("grad_clip", 10.0)
        self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
@@ -93,6 +97,15 @@
            logging.warning("distributed is not initialized, only single shard")
        self.rank = rank
        self.world_size = world_size
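        # Bookkeeping for metrics and checkpoints: running epoch averages,
        # index of the best validation accuracy, and checkpoints kept on disk.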
        self.train_acc_avg = 0.0
        self.train_loss_avg = 0.0
        self.val_acc_avg = 0.0
        self.val_loss_avg = 0.0
        self.best_acc_idx = 0
        self.saved_ckpts = {}
        self.val_acc_list = []
        self.step_or_epoch = -1
        
        
@@ -112,28 +125,56 @@
        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        if self.rank == 0:
            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
            self.step_or_epoch += 1
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optim.state_dict(),
                'scheduler': scheduler.state_dict(),
                "acc": self.val_acc_list,
                "step_or_epoch": self.step_or_epoch,
            }
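            # If the model is wrapped (DDP/FSDP exposes .module), save the
            # unwrapped weights so the checkpoint loads into a plain model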
            if hasattr(model, "module"):
                state["state_dict"] = model.module.state_dict()
            if scaler:
                state["scaler_state"] = scaler.state_dict()
            # Create output directory if it does not exist
            os.makedirs(self.output_dir, exist_ok=True)
            if step is None:
                ckpt_name = f'model.pt.ep{epoch}'
            else:
                ckpt_name = f'model.pt.ep{epoch}.{step}'
            filename = os.path.join(self.output_dir, ckpt_name)
            torch.save(state, filename)
            
            logging.info(f'\nCheckpoint saved to {filename}\n')
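            # Also refresh model.pt so it always points at the latest checkpoint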
            latest = Path(os.path.join(self.output_dir, 'model.pt'))
            torch.save(state, latest)
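            # Keep a separate copy of the checkpoint with the best validation
            # accuracy seen so far as model.pt.best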
            if self.val_acc_list[self.step_or_epoch] >= self.val_acc_list[self.best_acc_idx]:
                self.best_acc_idx = self.step_or_epoch
                best_ckpt = Path(os.path.join(self.output_dir, 'model.pt.best'))
                torch.save(state, best_ckpt)
                logging.info(f"Update best acc: {self.val_acc_list[self.best_acc_idx]}, {best_ckpt}")
            else:
                logging.info(f"No improvement in acc: {self.val_acc_list[self.best_acc_idx]}")
            if self.keep_nbest_models > 0:
                self.saved_ckpts[ckpt_name] = self.val_acc_list[-1]
                if len(self.saved_ckpts) > self.keep_nbest_models:
                    min_key = min(self.saved_ckpts, key=self.saved_ckpts.get)
                    del self.saved_ckpts[min_key]
                    filename = os.path.join(self.output_dir, min_key)
                    logging.info(f"Delete: {filename}")
                    if os.path.exists(filename):
                        os.remove(filename)
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
    
@@ -173,6 +214,10 @@
                scheduler.load_state_dict(checkpoint['scheduler'])
                if scaler is not None and 'scaler_state' in checkpoint:
                    scaler.load_state_dict(checkpoint['scaler_state'])
                # Older checkpoints may predate these fields; fall back gracefully
                self.val_acc_list = checkpoint.get("acc", [])
                self.step_or_epoch = checkpoint.get("step_or_epoch", -1)
                print(f"Checkpoint loaded successfully from '{ckpt}'")
            else:
                print(f"No checkpoint found at '{ckpt}', does not resume status!")
@@ -180,52 +225,7 @@
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        
    def train_epoch(self,
                model=None,
                optim=None,
@@ -241,9 +241,9 @@
        Args:
            epoch (int): The current epoch number.
        """
        logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
        model.train()
        # Set the number of steps for gradient accumulation
        accum_grad = self.accum_grad
        # Initialize the gradient accumulation
@@ -288,6 +288,18 @@
                    loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
                self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
                if "acc" in stats:
                    self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
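                # Under DDP/FSDP, average the running metrics across ranks so
                # every process logs the same epoch-level numbers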
                if self.use_ddp or self.use_fsdp:
                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
            
            # Perform an optimizer step only after accumulating enough gradients
            if (batch_idx + 1) % accum_grad == 0:
@@ -322,9 +334,11 @@
    
                speed_stats["total_time"] = total_time
                lr = scheduler.get_last_lr()[0]
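                # Iterable-style dataloaders may not implement __len__; report -1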
                batch_num_epoch = -1
                if hasattr(dataloader_train, "__len__"):
                    batch_num_epoch = len(dataloader_train)
                self.log(epoch, batch_idx,
                         batch_num_epoch=batch_num_epoch,
                         lr=lr,
                         loss=loss.detach().cpu().item(),
                         speed_stats=speed_stats,
@@ -341,7 +355,7 @@
                    writer=writer
                )
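            # Every rank must call save_checkpoint: only rank 0 writes, but the
            # method ends with a dist.barrier() that all ranks have to reach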
            if (batch_idx+1) % self.save_checkpoint_interval == 0:
                self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
        
@@ -364,6 +378,7 @@
        Args:
            epoch (int): The current epoch number.
        """
        logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
        model.eval()
        
        with torch.no_grad():
@@ -394,18 +409,35 @@
                time4 = time.perf_counter()
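                # Running average of validation loss/acc over this pass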
                self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
                if "acc" in stats:
                    self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
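                # Average validation metrics across ranks under DDP/FSDP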
                if self.use_ddp or self.use_fsdp:
                    val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
                    val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
                    dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
                    dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
                    self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
                    self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
                
                batch_num_epoch = -1
                if hasattr(dataloader_val, "__len__"):
                    batch_num_epoch = len(dataloader_val)
                self.log(epoch, batch_idx,
                         batch_num_epoch=batch_num_epoch,
                         lr=0.0,
                         loss=loss.detach().cpu().item(),
                         speed_stats=speed_stats,
                         stats=stats,
                         writer=writer,
                         tag="train",
                         tag="val",
                         )
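        # Record this epoch's validation accuracy; save_checkpoint reads this
        # history to decide which checkpoint is best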
        self.val_acc_list.append(self.val_acc_avg)
        model.train()
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        
        
    def log(self,
@@ -422,39 +454,47 @@
        
        if (batch_idx + 1) % self.log_interval == 0:
            
            gpu_info = "GPU, memory: {:.3f} GB, " \
                       "{:.3f} GB, " \
                       "{:.3f} GB, " \
                       "{:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
            gpu_info = "GPU, memory: usage: {:.3f} GB, " \
                       "peak: {:.3f} GB, " \
                       "cache: {:.3f} GB, " \
                       "cache_peak: {:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
                                          torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
                                          torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
                                          torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
                                          )
            
            time_now = datetime.now()
            time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
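            # tag is "train" or "val"; fetch the matching running averages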
            loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
            acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
            description = (
                f"{time_now}, "
                f"{tag}, "
                f"rank: {self.local_rank}, "
                f"epoch: {epoch}/{self.max_epoch}, "
                f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
                f"(loss: {loss:.3f}), "
                f"(loss_avg_rank: {loss:.3f}), "
                f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
                f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}), "
                f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
                f"(lr: {lr:.3e}), "
                f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
                f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
                f"{speed_stats}, "
                f"{gpu_info}"
            )
            logging.info(description)
            
            if writer is not None:
                writer.add_scalar(f'rank{self.local_rank}_loss/{tag}', loss, self.batch_total)
                writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
                for key, var in stats.items():
                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
                for key, var in speed_stats.items():
                    # speed_stats holds numeric strings; parse with float()
                    # rather than eval() so no arbitrary string is executed
                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', float(var), self.batch_total)
        
    def close(self, writer=None):
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        if writer is not None:
            writer.close()