From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/train_utils/trainer.py |  282 +++++++++++++++++++++++++++++++++++++-------------------
 1 files changed, 187 insertions(+), 95 deletions(-)

diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 8f20ba4..3e69985 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -15,6 +15,11 @@
 from funasr.train_utils.average_nbest_models import average_checkpoints
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
+try:
+    import wandb
+except:
+    wandb = None
+
 
 @contextmanager
 def maybe_autocast(enabled):
@@ -80,7 +85,12 @@
         self.batch_total = 0
         self.use_fp16 = use_fp16
         self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
-        self.validate_interval = kwargs.get("validate_interval", 5000)
+        self.validate_interval = kwargs.get("validate_interval", -1)
+        if self.validate_interval < 0:
+            self.validate_interval = self.save_checkpoint_interval
+        assert (
+            self.save_checkpoint_interval == self.validate_interval
+        ), f"save_checkpoint_interval must equal to validate_interval"
         self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
         self.avg_keep_nbest_models_type = kwargs.get("avg_keep_nbest_models_type", "acc")
         self.avg_nbest_model = kwargs.get("avg_nbest_model", 10)
@@ -105,11 +115,25 @@
         self.saved_ckpts = {}
         self.step_or_epoch = -1
         self.best_step_or_epoch = ""
-        self.val_acc_step_or_eoch = {}
-        self.val_loss_step_or_eoch = {}
-        
-        self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
+        self.val_acc_step_or_epoch = {}
+        self.val_loss_step_or_epoch = {}
 
+        self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
+        self.start_data_split_i = 0
+        self.start_step = 0
+        self.step_in_epoch = 0
+        self.use_wandb = kwargs.get("use_wandb", False)
+        if self.use_wandb:
+            wandb.login(key=kwargs.get("wandb_token"))
+            wandb.init(
+                config=kwargs,
+                project=kwargs.get("wandb_project", "my_project"),
+                entity=kwargs.get("wandb_team", "my_team"),
+                name=kwargs.get("wandb_exp_name", "my_exp"),
+                dir=output_dir,
+                job_type="training",
+                reinit=True,
+            )
 
     def save_checkpoint(
         self,
@@ -119,6 +143,8 @@
         optim=None,
         scheduler=None,
         scaler=None,
+        step_in_epoch=None,
+        **kwargs,
     ):
         """
         Saves a checkpoint containing the model's state, the optimizer's state,
@@ -129,25 +155,36 @@
             epoch (int): The epoch number at which the checkpoint is being saved.
         """
 
+        step_in_epoch = None if step is None else step_in_epoch
         if self.rank == 0:
             logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
             # self.step_or_epoch += 1
             state = {
                 "epoch": epoch,
+                "step": step,
+                "total_step": self.batch_total,
                 "state_dict": model.state_dict(),
                 "optimizer": optim.state_dict(),
                 "scheduler": scheduler.state_dict(),
                 "saved_ckpts": self.saved_ckpts,
-                "val_acc_step_or_eoch": self.val_acc_step_or_eoch,
-                "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
+                "val_acc_step_or_epoch": self.val_acc_step_or_epoch,
+                "val_loss_step_or_epoch": self.val_loss_step_or_epoch,
                 "best_step_or_epoch": self.best_step_or_epoch,
                 "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
+                "step_in_epoch": step_in_epoch,
+                "data_split_i": kwargs.get("data_split_i", 0),
+                "data_split_num": kwargs.get("data_split_num", 1),
+                "batch_total": self.batch_total,
+                "train_loss_avg": kwargs.get("train_loss_avg", 0),
+                "train_acc_avg": kwargs.get("train_acc_avg", 0),
             }
+            step = step_in_epoch
             if hasattr(model, "module"):
                 state["state_dict"] = model.module.state_dict()
 
             if scaler:
                 state["scaler_state"] = scaler.state_dict()
+
             # Create output directory if it does not exist
             os.makedirs(self.output_dir, exist_ok=True)
             if step is None:
@@ -156,47 +193,48 @@
                 ckpt_name = f"model.pt.ep{epoch}.{step}"
             filename = os.path.join(self.output_dir, ckpt_name)
             torch.save(state, filename)
+            logging.info(f"Checkpoint saved to {filename}")
 
-            logging.info(f"\nCheckpoint saved to {filename}\n")
             latest = Path(os.path.join(self.output_dir, f"model.pt"))
             torch.save(state, latest)
+
             if self.best_step_or_epoch == "":
                 self.best_step_or_epoch = ckpt_name
 
             if self.avg_keep_nbest_models_type == "acc":
                 if (
-                    self.val_acc_step_or_eoch[ckpt_name]
-                    >= self.val_acc_step_or_eoch[self.best_step_or_epoch]
+                    self.val_acc_step_or_epoch[ckpt_name]
+                    >= self.val_acc_step_or_epoch[self.best_step_or_epoch]
                 ):
                     self.best_step_or_epoch = ckpt_name
                     best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
                     torch.save(state, best_ckpt)
                     logging.info(
-                        f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+                        f"Update best acc: {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
                     )
                 else:
                     logging.info(
-                        f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}"
+                        f"No improvement in acc: {self.val_acc_step_or_epoch[ckpt_name]:.4f} < {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
                     )
             elif self.avg_keep_nbest_models_type == "loss":
                 if (
-                    self.val_loss_step_or_eoch[ckpt_name]
-                    <= self.val_loss_step_or_eoch[self.best_step_or_epoch]
+                    self.val_loss_step_or_epoch[ckpt_name]
+                    <= self.val_loss_step_or_epoch[self.best_step_or_epoch]
                 ):
                     self.best_step_or_epoch = ckpt_name
                     best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
                     torch.save(state, best_ckpt)
                     logging.info(
-                        f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+                        f"Update best loss: {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
                     )
                 else:
                     logging.info(
-                        f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}"
+                        f"No improvement in loss: {self.val_loss_step_or_epoch[ckpt_name]:.4f} > {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
                     )
             else:
                 print("Undo")
             self.saved_ckpts[ckpt_name] = getattr(
-                self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
+                self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch"
             )[ckpt_name]
             if self.keep_nbest_models > 0:
                 if len(self.saved_ckpts) > self.keep_nbest_models:
@@ -232,7 +270,7 @@
             ckpt = os.path.join(self.output_dir, "model.pt")
             if os.path.isfile(ckpt):
                 checkpoint = torch.load(ckpt, map_location="cpu")
-                self.start_epoch = checkpoint["epoch"] + 1
+                self.start_epoch = checkpoint["epoch"]
                 # self.model.load_state_dict(checkpoint['state_dict'])
                 src_state = checkpoint["state_dict"]
                 dst_state = model.state_dict()
@@ -243,6 +281,7 @@
                         k_ddp = k.replace("module.", "", 1)
                     else:
                         k_ddp = k
+
                     if k_ddp in src_state.keys():
                         dst_state[k] = src_state[k_ddp]
                     else:
@@ -255,18 +294,35 @@
                     scaler.load_state_dict(checkpoint["scaler_state"])
 
                 self.saved_ckpts = checkpoint["saved_ckpts"]
-                self.val_acc_step_or_eoch = (
-                    checkpoint["val_acc_step_or_eoch"]
-                    if "val_acc_step_or_eoch" in checkpoint
+                self.val_acc_step_or_epoch = (
+                    checkpoint["val_acc_step_or_epoch"]
+                    if "val_acc_step_or_epoch" in checkpoint
                     else {}
                 )
-                self.val_loss_step_or_eoch = (
-                    checkpoint["val_loss_step_or_eoch"]
-                    if "val_loss_step_or_eoch" in checkpoint
+                self.val_loss_step_or_epoch = (
+                    checkpoint["val_loss_step_or_epoch"]
+                    if "val_loss_step_or_epoch" in checkpoint
                     else {}
                 )
                 self.best_step_or_epoch = (
                     checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
+                )
+                self.start_data_split_i = (
+                    checkpoint["data_split_i"] if "data_split_i" in checkpoint else 0
+                )
+                self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
+                self.start_step = checkpoint["step"] if "step" in checkpoint else 0
+                self.start_step = 0 if self.start_step is None else self.start_step
+                self.step_in_epoch = (
+                    checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
+                )
+                self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
+                print(checkpoint["train_acc_avg"])
+                self.train_acc_avg = (
+                    checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
+                )
+                self.train_loss_avg = (
+                    checkpoint["train_loss_avg"] if "train_loss_avg" in checkpoint else 0
                 )
                 model.to(self.device)
                 print(f"Checkpoint loaded successfully from '{ckpt}'")
@@ -295,7 +351,7 @@
         """
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
+        logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
         model.train()
 
         # Set the number of steps for gradient accumulation
@@ -310,11 +366,12 @@
         time_beg = time.perf_counter()
         time5 = time_beg
         for batch_idx, batch in enumerate(dataloader_train):
-            if self.use_ddp or self.use_fsdp:
-                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
-                if iterator_stop > 0:
-                    break
+            # if self.use_ddp or self.use_fsdp:
+            #     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+            #     if iterator_stop > 0:
+            #         break
             self.batch_total += 1
+            self.step_in_epoch += 1
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
 
@@ -327,15 +384,13 @@
                 time2 = time.perf_counter()
                 with maybe_autocast(self.use_fp16):
                     retval = model(**batch)
-                    
-                    if (
-                        self.reset_gpu_cache
-                        and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
-                    ):
-                        torch.cuda.empty_cache()
 
-                time3 = time.perf_counter()
-                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+                    # if (
+                    #     self.reset_gpu_cache
+                    #     and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
+                    # ):
+                    #     torch.cuda.empty_cache()
+
                 loss, stats, weight = retval
                 stats = {k: v for k, v in stats.items() if v is not None}
                 if self.use_ddp or self.use_fsdp:
@@ -350,33 +405,28 @@
                     # Multiply world_size because DistributedDataParallel
                     # automatically normalizes the gradient by world_size.
                     loss *= self.world_size
+                # loss *= self.world_size
                 # Scale the loss since we're not updating for every mini-batch
                 loss = loss / accum_grad
+
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                 if self.use_fp16:
                     scaler.scale(loss).backward()
                 else:
                     loss.backward()
                 time4 = time.perf_counter()
-                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
+                speed_stats["backward_and_AllReaduce_time"] = f"{time4 - time3:0.3f}"
 
                 self.train_loss_avg = (
-                    self.train_loss_avg * batch_idx + loss.detach().cpu().item()
-                ) / (batch_idx + 1)
+                    self.train_loss_avg * (batch_idx + kwargs.get("start_step", 0))
+                    + loss.detach().cpu().item()
+                ) / (batch_idx + kwargs.get("start_step", 0) + 1)
                 if "acc" in stats:
                     self.train_acc_avg = (
-                        self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
-                    ) / (batch_idx + 1)
-                if self.use_ddp or self.use_fsdp:
-                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
-                        self.device
-                    )
-                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
-                        self.device
-                    )
-                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
-                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
-                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
-                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+                        self.train_acc_avg * (batch_idx + kwargs.get("start_step", 0))
+                        + stats["acc"].detach().cpu().item()
+                    ) / (batch_idx + kwargs.get("start_step", 0) + 1)
 
             # Perform an optimizer step only after accumulating enough gradients
             if (batch_idx + 1) % accum_grad == 0:
@@ -405,8 +455,22 @@
                 scheduler.step()
                 # Clear gradients for the next accumulation stage
                 optim.zero_grad(set_to_none=True)
-                total_time = f"{time.perf_counter() - time5:0.3f}"
+
+                if self.use_ddp or self.use_fsdp:
+                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
+                        self.device
+                    )
+                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
+                        self.device
+                    )
+                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+
+                total_time = f"{(time.perf_counter() - time5)/accum_grad:0.3f}"
                 time5 = time.perf_counter()
+
                 speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
 
                 speed_stats["total_time"] = total_time
@@ -417,9 +481,11 @@
                 self.log(
                     epoch,
                     batch_idx,
+                    log_step=batch_idx + kwargs.get("start_step", 0),
+                    step_in_epoch=self.step_in_epoch,
                     batch_num_epoch=batch_num_epoch,
                     lr=lr,
-                    loss=loss.detach().cpu().item(),
+                    loss=accum_grad * loss.detach().cpu().item(),
                     speed_stats=speed_stats,
                     stats=stats,
                     writer=writer,
@@ -428,16 +494,17 @@
                     data_split_num=kwargs.get("data_split_num", 1),
                 )
 
-            if (batch_idx + 1) % self.validate_interval == 0:
+            if self.step_in_epoch % self.validate_interval == 0:
                 self.validate_epoch(
                     model=model,
                     dataloader_val=dataloader_val,
                     epoch=epoch,
                     writer=writer,
                     step=batch_idx + 1,
+                    step_in_epoch=self.step_in_epoch,
                 )
 
-            if (batch_idx + 1) % self.save_checkpoint_interval == 0:
+            if self.step_in_epoch % self.save_checkpoint_interval == 0:
                 self.save_checkpoint(
                     epoch,
                     model=model,
@@ -445,17 +512,22 @@
                     scheduler=scheduler,
                     scaler=scaler,
                     step=batch_idx + 1,
+                    step_in_epoch=self.step_in_epoch,
+                    data_split_i=kwargs.get("data_split_i", 0),
+                    data_split_num=kwargs.get("data_split_num", 1),
+                    train_loss_avg=self.train_loss_avg,
+                    train_acc_avg=self.train_acc_avg,
                 )
 
             time_beg = time.perf_counter()
-        else:
-            if self.use_ddp or self.use_fsdp:
-                iterator_stop.fill_(1)
-                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+        # else:
+        #     if self.use_ddp or self.use_fsdp:
+        #         iterator_stop.fill_(1)
+        #         dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-            iterator_stop = torch.tensor(0).to(self.device)
+            # iterator_stop = torch.tensor(0).to(self.device)
 
     def validate_epoch(
         self,
@@ -474,7 +546,7 @@
         """
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
+        logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
         model.eval()
 
         with torch.no_grad():
@@ -491,12 +563,14 @@
                 time1 = time.perf_counter()
                 speed_stats["data_load"] = f"{time1 - time5:0.3f}"
                 batch = to_device(batch, self.device)
+
                 time2 = time.perf_counter()
                 retval = model(**batch)
                 time3 = time.perf_counter()
                 speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                 loss, stats, weight = retval
                 stats = {k: v for k, v in stats.items() if v is not None}
+
                 if self.use_ddp or self.use_fsdp:
                     # Apply weighted averaging for loss and stats
                     loss = (loss * weight.type(loss.dtype)).sum()
@@ -509,28 +583,33 @@
                     # Multiply world_size because DistributedDataParallel
                     # automatically normalizes the gradient by world_size.
                     loss *= self.world_size
+
                 # Scale the loss since we're not updating for every mini-batch
                 loss = loss
                 time4 = time.perf_counter()
 
-                self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
-                    batch_idx + 1
-                )
-                if "acc" in stats:
-                    self.val_acc_avg = (
-                        self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
+                if torch.isfinite(loss):
+                    self.val_loss_avg = (
+                        self.val_loss_avg * batch_idx + loss.detach().cpu().item()
                     ) / (batch_idx + 1)
-                if self.use_ddp or self.use_fsdp:
-                    val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
-                        self.device
-                    )
-                    val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
-                        self.device
-                    )
-                    dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
-                    dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
-                    self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
-                    self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+
+                    if "acc" in stats:
+                        self.val_acc_avg = (
+                            self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
+                        ) / (batch_idx + 1)
+
+                    if self.use_ddp or self.use_fsdp:
+                        val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
+                            self.device
+                        )
+                        val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
+                            self.device
+                        )
+                        dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+                        dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+                        self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+                        self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+
                 time5 = time.perf_counter()
                 batch_num_epoch = 1
                 if hasattr(dataloader_val, "__len__"):
@@ -552,12 +631,12 @@
                     iterator_stop.fill_(1)
                     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
 
-        if kwargs.get("step", None) is None:
+        if kwargs.get("step_in_epoch", None) is None:
             ckpt_name = f"model.pt.ep{epoch}"
         else:
-            ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step")}'
-        self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
-        self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
+            ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
+        self.val_acc_step_or_epoch[ckpt_name] = self.val_acc_avg
+        self.val_loss_step_or_epoch[ckpt_name] = self.val_loss_avg
         model.train()
 
         if self.use_ddp or self.use_fsdp:
@@ -568,6 +647,7 @@
         self,
         epoch=0,
         batch_idx=0,
+        step_in_epoch=0,
         batch_num_epoch=-1,
         lr=0.0,
         loss=0.0,
@@ -577,11 +657,12 @@
         tag="train",
         data_split_i=0,
         data_split_num=1,
+        log_step=None,
         **kwargs,
     ):
 
         if (batch_idx + 1) % self.log_interval == 0:
-
+            batch_idx = log_step if log_step is not None else batch_idx
             gpu_info = (
                 "GPU, memory: usage: {:.3f} GB, "
                 "peak: {:.3f} GB, "
@@ -598,14 +679,14 @@
             acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
             description = (
                 f"{tag}, "
-                f"rank: {self.local_rank}, "
+                f"rank: {self.rank}, "
                 f"epoch: {epoch}/{self.max_epoch}, "
                 f"data_slice: {data_split_i}/{data_split_num}, "
-                f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
+                f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, "
                 f"(loss_avg_rank: {loss:.3f}), "
-                f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
-                f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3e}), "
-                f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
+                f"(loss_avg_slice: {loss_avg_epoch:.3f}), "
+                f"(ppl_avg_slice: {math.exp(loss_avg_epoch):.3e}), "
+                f"(acc_avg_slice: {acc_avg_epoch:.3f}), "
                 f"(lr: {lr:.3e}), "
                 f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
                 f"{speed_stats}, "
@@ -613,18 +694,29 @@
             )
             logging.info(description)
 
+            description_dict = {
+                f"rank{self.rank}_loss/{tag}": loss,
+                f"rank{self.rank}_lr/{tag}": lr,
+            }
+
             if writer is not None:
-                writer.add_scalar(f"rank{self.local_rank}_loss/{tag}", loss, self.batch_total)
-                writer.add_scalar(f"rank{self.local_rank}_lr/{tag}", lr, self.batch_total)
-                writer.add_scalar(f"rank{self.local_rank}_lr/{tag}", lr, self.batch_total)
+                writer.add_scalar(f"rank{self.rank}_loss/{tag}", loss, self.batch_total)
+                writer.add_scalar(f"rank{self.rank}_lr/{tag}", lr, self.batch_total)
                 for key, var in stats.items():
                     writer.add_scalar(
-                        f"stats_rank{self.local_rank}_{key}/{tag}", var.item(), self.batch_total
+                        f"stats_rank{self.rank}_{key}/{tag}", var.item(), self.batch_total
                     )
+                    description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = var.item()
                 for key, var in speed_stats.items():
                     writer.add_scalar(
-                        f"stats_rank{self.local_rank}_{key}/{tag}", eval(var), self.batch_total
+                        f"stats_rank{self.rank}_{key}/{tag}", eval(var), self.batch_total
                     )
+                    description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = eval(var)
+            if self.use_wandb and wandb is not None:
+                wandb.log(
+                    description_dict,
+                    step=self.batch_total,
+                )
 
     def close(self, writer=None):
 

--
Gitblit v1.9.1