From 3ac03e448b7673604eb86f619b27521fca55f34d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 21 Mar 2024 01:36:39 +0800
Subject: [PATCH] train & finetune llm-asr (#1519)

---
 funasr/train_utils/trainer_llm.py |  178 ++++++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 109 insertions(+), 69 deletions(-)

diff --git a/funasr/train_utils/trainer_llm.py b/funasr/train_utils/trainer_llm.py
index 6a3b83b..5f13b5a 100644
--- a/funasr/train_utils/trainer_llm.py
+++ b/funasr/train_utils/trainer_llm.py
@@ -1,3 +1,4 @@
+import math
 import os
 import time
 import torch
@@ -61,6 +62,8 @@
         """
         
         self.output_dir = output_dir
+        if not os.path.exists(self.output_dir):
+            os.makedirs(self.output_dir, exist_ok=True)
         self.resume = kwargs.get('resume', True)
         self.start_epoch = 0
         self.max_epoch = kwargs.get('max_epoch', 100)
@@ -78,6 +81,7 @@
         # scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
         # self.scaler = scaler
         self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
+        self.keep_nbest_models = kwargs.get("keep_nbest_models", -1)
         self.accum_grad = kwargs.get("accum_grad", 1)
         self.grad_clip = kwargs.get("grad_clip", 10.0)
         self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
@@ -93,6 +97,15 @@
             logging.warning("distributed is not initialized, only single shard")
         self.rank = rank
         self.world_size = world_size
+        self.train_acc_avg = 0.0
+        self.train_loss_avg = 0.0
+        self.val_acc_avg = 0.0
+        self.val_loss_avg = 0.0
+        self.best_acc_idx = 0
+        self.saved_ckpts = {}
+        self.val_acc_list = []
+        self.step_or_epoch = -1
+        
         
 
         
@@ -112,28 +125,56 @@
         Args:
             epoch (int): The epoch number at which the checkpoint is being saved.
         """
+        
         if self.rank == 0:
+            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
+            self.step_or_epoch += 1
             state = {
                 'epoch': epoch,
                 'state_dict': model.state_dict(),
                 'optimizer': optim.state_dict(),
                 'scheduler': scheduler.state_dict(),
+                "acc": self.val_acc_list,
+                "step_or_epoch": self.step_or_epoch,
             }
+            if hasattr(model, "module"):
+                state["state_dict"] = model.module.state_dict()
+                
             if scaler:
                 state["scaler_state"] = scaler.state_dict()
             # Create output directory if it does not exist
             os.makedirs(self.output_dir, exist_ok=True)
             if step is None:
-                filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
+                ckpt_name = f'model.pt.ep{epoch}'
             else:
-                filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')
-            
+                ckpt_name = f'model.pt.ep{epoch}.{step}'
+            filename = os.path.join(self.output_dir, ckpt_name)
             torch.save(state, filename)
             
-            print(f'\nCheckpoint saved to {filename}\n')
+            logging.info(f'\nCheckpoint saved to {filename}\n')
             latest = Path(os.path.join(self.output_dir, f'model.pt'))
             torch.save(state, latest)
-        
+            
+            if self.val_acc_list[self.step_or_epoch] >= self.val_acc_list[self.best_acc_idx]:
+                self.best_acc_idx = self.step_or_epoch
+                best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
+                torch.save(state, best_ckpt)
+                logging.info(f"Update best acc: {self.val_acc_list[self.best_acc_idx]}, {best_ckpt}")
+            else:
+                logging.info(f"No improvement in acc: {self.val_acc_list[self.best_acc_idx]}")
+            
+            if self.keep_nbest_models > 0:
+                self.saved_ckpts[ckpt_name] = self.val_acc_list[-1]
+                if len(self.saved_ckpts) > self.keep_nbest_models:
+
+                    min_key = min(self.saved_ckpts, key=self.saved_ckpts.get)
+                    if min_key in self.saved_ckpts:
+                        del self.saved_ckpts[min_key]
+                    filename = os.path.join(self.output_dir, min_key)
+                    logging.info(f"Delete: {filename}")
+                    if os.path.exists(filename):
+                        os.remove(filename)
+                
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
     
@@ -173,6 +214,10 @@
                 scheduler.load_state_dict(checkpoint['scheduler'])
                 if scaler is not None and 'scaler_state' in checkpoint:
                     scaler.load_state_dict(checkpoint['scaler_state'])
+                
+                self.val_acc_list = checkpoint["acc"]
+                self.step_or_epoch = checkpoint["step_or_epoch"]
+                
                 print(f"Checkpoint loaded successfully from '{ckpt}'")
             else:
                 print(f"No checkpoint found at '{ckpt}', does not resume status!")
@@ -180,52 +225,7 @@
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
         
-    # def train(self):
-    #     """
-    #     Starts the training process, iterating over epochs, training the model,
-    #     and saving checkpoints at the end of each epoch.
-    #     """
-    #     if self.resume:
-    #         self.resume_checkpoint(self.output_dir)
-    #
-    #     for epoch in range(self.start_epoch, self.max_epoch + 1):
-    #         time1 = time.perf_counter()
-    #         self.train_epoch(epoch)
-    #
-    #
-    #
-    #         if self.use_ddp or self.use_fsdp:
-    #             dist.barrier()
-    #
-    #         self._validate_epoch(epoch)
-    #
-    #         if self.use_ddp or self.use_fsdp:
-    #             dist.barrier()
-    #
-    #
-    #         if self.rank == 0:
-    #             self._save_checkpoint(epoch)
-    #
-    #         if self.use_ddp or self.use_fsdp:
-    #             dist.barrier()
-    #
-    #         self.scheduler.step()
-    #
-    #         time2 = time.perf_counter()
-    #         time_escaped = (time2 - time1)/3600.0
-    #         print(f"\nrank: {self.local_rank}, time_escaped_epoch: {time_escaped:.3f} hours, estimated to finish {self.max_epoch} epoch: {(self.max_epoch-epoch)*time_escaped:.3f} hours\n")
-    #
-    #     if self.rank == 0:
-    #         average_checkpoints(self.output_dir, self.avg_nbest_model)
-    #
-    #     if self.use_ddp or self.use_fsdp:
-    #         dist.barrier()
-    #
-    #
-    #     if writer:
-    #         writer.close()
-    #
-    
+ 
     def train_epoch(self,
                 model=None,
                 optim=None,
@@ -241,9 +241,9 @@
         Args:
             epoch (int): The current epoch number.
         """
+        logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
         model.train()
 
-        
         # Set the number of steps for gradient accumulation
         accum_grad = self.accum_grad
         # Initialize the gradient accumulation
@@ -288,6 +288,18 @@
                     loss.backward()
                 time4 = time.perf_counter()
                 speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
+                
+                self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
+                if "acc" in stats:
+                    self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
+                if self.use_ddp or self.use_fsdp:
+                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
+                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
+                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+                
             
             # Perform an optimizer step only after accumulating enough gradients
             if (batch_idx + 1) % accum_grad == 0:
@@ -322,9 +334,11 @@
     
                 speed_stats["total_time"] = total_time
                 lr = scheduler.get_last_lr()[0]
-
+                batch_num_epoch = -1
+                if hasattr(dataloader_train, "__len__"):
+                    batch_num_epoch = len(dataloader_train)
                 self.log(epoch, batch_idx,
-                         batch_num_epoch=len(dataloader_train),
+                         batch_num_epoch=batch_num_epoch,
                          lr=lr,
                          loss=loss.detach().cpu().item(),
                          speed_stats=speed_stats,
@@ -341,7 +355,7 @@
                     writer=writer
                 )
 
-            if (batch_idx+1) % self.save_checkpoint_interval == 0 and self.rank == 0:
+            if (batch_idx+1) % self.save_checkpoint_interval == 0:
                 self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
 
         
@@ -364,6 +378,7 @@
         Args:
             epoch (int): The current epoch number.
         """
+        logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
         model.eval()
         
         with torch.no_grad():
@@ -394,18 +409,35 @@
                 loss = loss
                 time4 = time.perf_counter()
 
+                self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
+                if "acc" in stats:
+                    self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
+                if self.use_ddp or self.use_fsdp:
+                    val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
+                    val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
+                    dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+                    self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+                    self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
                 
+                batch_num_epoch = -1
+                if hasattr(dataloader_val, "__len__"):
+                    batch_num_epoch = len(dataloader_val)
                 self.log(epoch, batch_idx,
-                         batch_num_epoch=len(dataloader_val),
+                         batch_num_epoch=batch_num_epoch,
                          lr=0.0,
                          loss=loss.detach().cpu().item(),
                          speed_stats=speed_stats,
                          stats=stats,
                          writer=writer,
-                         tag="train",
+                         tag="val",
                          )
 
+        self.val_acc_list.append(self.val_acc_avg)
         model.train()
+        
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
         
         
     def log(self,
@@ -422,39 +454,47 @@
         
         if (batch_idx + 1) % self.log_interval == 0:
             
-            gpu_info = "GPU, memory: {:.3f} GB, " \
-                       "{:.3f} GB, " \
-                       "{:.3f} GB, " \
-                       "{:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+            gpu_info = "GPU, memory: usage: {:.3f} GB, " \
+                       "peak: {:.3f} GB, " \
+                       "cache: {:.3f} GB, " \
+                       "cache_peak: {:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
                                           torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
                                           torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
                                           torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
                                           )
             
-            time_now = datetime.now()
-            time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
+            loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
+            acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
             description = (
-                f"{time_now}, "
+                f"{tag}, "
                 f"rank: {self.local_rank}, "
                 f"epoch: {epoch}/{self.max_epoch}, "
                 f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
-                f"(loss: {loss:.3f}), "
+                f"(loss_avg_rank: {loss:.3f}), "
+                f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
+                f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}), "
+                f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
                 f"(lr: {lr:.3e}), "
-                f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
+                f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
                 f"{speed_stats}, "
                 f"{gpu_info}"
             )
             logging.info(description)
             
             if writer is not None:
-                writer.add_scalar(f'rank{self.local_rank}_Loss/{tag}', loss, self.batch_total)
+                writer.add_scalar(f'rank{self.local_rank}_loss/{tag}', loss, self.batch_total)
+                writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
                 writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
                 for key, var in stats.items():
-                    writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
+                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
                 for key, var in speed_stats.items():
-                    writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total)
+                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total)
         
     def close(self, writer=None):
+        
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+        
         if writer is not None:
             writer.close()
     

--
Gitblit v1.9.1