From 2ac38adbe5f4e1374a079e032ed4b504351a207c Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 23 四月 2024 18:08:57 +0800
Subject: [PATCH] Dev gzf exp (#1647)

---
 funasr/train_utils/trainer.py |   39 ++++++++++++++++++++++++++-------------
 1 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 116c9e3..3ee6885 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -79,7 +79,7 @@
         self.validate_interval = kwargs.get("validate_interval", 5000)
         self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
         self.avg_keep_nbest_models_type = kwargs.get("avg_keep_nbest_models_type", "acc")
-        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
+        self.avg_nbest_model = kwargs.get("avg_nbest_model", 10)
         self.accum_grad = kwargs.get("accum_grad", 1)
         self.grad_clip = kwargs.get("grad_clip", 10.0)
         self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
@@ -134,7 +134,7 @@
                 "val_acc_step_or_eoch": self.val_acc_step_or_eoch,
                 "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
                 "best_step_or_epoch": self.best_step_or_epoch,
-                "avg_keep_nbest_models_type": slef.avg_keep_nbest_models_type,
+                "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
             }
             if hasattr(model, "module"):
                 state["state_dict"] = model.module.state_dict()
@@ -161,17 +161,17 @@
                     self.best_step_or_epoch = ckpt_name
                     best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
                     torch.save(state, best_ckpt)
-                    logging.info(f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]}, {best_ckpt}")
+                    logging.info(f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}")
                 else:
-                    logging.info(f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]}")
+                    logging.info(f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}")
             elif self.avg_keep_nbest_models_type == "loss":
                 if self.val_loss_step_or_eoch[ckpt_name] <= self.val_loss_step_or_eoch[self.best_step_or_epoch]:
                     self.best_step_or_epoch = ckpt_name
                     best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
                     torch.save(state, best_ckpt)
-                    logging.info(f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]}, {best_ckpt}")
+                    logging.info(f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}")
                 else:
-                    logging.info(f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]}")
+                    logging.info(f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}")
             else:
                 print("Undo")
             self.saved_ckpts[ckpt_name] = getattr(self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch")[ckpt_name]
@@ -233,7 +233,7 @@
                 self.saved_ckpts = checkpoint["saved_ckpts"]
                 self.val_acc_step_or_eoch = checkpoint["val_acc_step_or_eoch"] if "val_acc_step_or_eoch" in checkpoint else {}
                 self.val_loss_step_or_eoch = checkpoint["val_loss_step_or_eoch"] if "val_loss_step_or_eoch" in checkpoint else {}
-                self.val_loss_step_or_eoch = checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
+                self.best_step_or_epoch = checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
                 model.to(self.device)
                 print(f"Checkpoint loaded successfully from '{ckpt}'")
             else:
@@ -252,6 +252,7 @@
                 dataloader_val=None,
                 epoch=None,
                 writer=None,
+                **kwargs,
                     ):
         """
         Defines the training process for a single epoch with gradient accumulation.
@@ -268,10 +269,12 @@
         # Initialize the gradient accumulation
         optim.zero_grad()
         speed_stats = {}
-        time5 = time.perf_counter()
+        
         iterator_stop = torch.tensor(0).to(self.device)
 
         dataloader_train.batch_sampler.set_epoch(epoch)
+        time_beg = time.perf_counter()
+        time5 = time_beg
         for batch_idx, batch in enumerate(dataloader_train):
             if self.use_ddp or self.use_fsdp:
                 dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
@@ -279,11 +282,13 @@
                     break
             self.batch_total += 1
             time1 = time.perf_counter()
-            speed_stats["data_load"] = f"{time1-time5:0.3f}"
+            speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
 
             batch = to_device(batch, self.device)
-            
-            my_context = model.no_sync if batch_idx % accum_grad != 0 else nullcontext
+
+            my_context = nullcontext
+            if self.use_ddp or self.use_fsdp:
+                my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context
             with my_context():
                 time2 = time.perf_counter()
                 with maybe_autocast(self.use_fp16):
@@ -370,6 +375,8 @@
                          stats=stats,
                          writer=writer,
                          tag="train",
+                         data_split_i=kwargs.get("data_split_i", 0),
+                         data_split_num=kwargs.get("data_split_num", 1),
                          )
 
             if (batch_idx + 1) % self.validate_interval == 0:
@@ -377,12 +384,14 @@
                     model=model,
                     dataloader_val=dataloader_val,
                     epoch=epoch,
-                    writer=writer
+                    writer=writer,
+                    step=batch_idx+1,
                 )
 
             if (batch_idx+1) % self.save_checkpoint_interval == 0:
                 self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
 
+            time_beg = time.perf_counter()
         else:
             if self.use_ddp or self.use_fsdp:
                 iterator_stop.fill_(1)
@@ -501,6 +510,9 @@
             stats=None,
             writer=None,
             tag="train",
+            data_split_i=0,
+            data_split_num=1,
+            **kwargs,
             ):
         
         if (batch_idx + 1) % self.log_interval == 0:
@@ -520,10 +532,11 @@
                 f"{tag}, "
                 f"rank: {self.local_rank}, "
                 f"epoch: {epoch}/{self.max_epoch}, "
+                f"data_slice: {data_split_i}/{data_split_num}, "
                 f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
                 f"(loss_avg_rank: {loss:.3f}), "
                 f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
-                f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}), "
+                f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3e}), "
                 f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
                 f"(lr: {lr:.3e}), "
                 f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "

--
Gitblit v1.9.1