From 6e86c5044d30dffe356b6e42838d01b7cfaf4272 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 26 Mar 2024 01:11:52 +0800
Subject: [PATCH] trainer: sync ranks at epoch boundaries and re-enable cross-rank loss/acc averaging

---
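Two distributed-training patterns recur in the hunks below and are worth a note.

The re-enabled averaging blocks in train_epoch() and validate_epoch() follow the
usual all-reduce-mean pattern: each rank contributes its local running average,
the group sums them, and dividing by world_size gives the global mean. A minimal
standalone sketch of that pattern (assuming an already initialized process
group; all_reduce_mean is a hypothetical helper, not part of trainer.py):

    import torch
    import torch.distributed as dist

    def all_reduce_mean(value: float, device: torch.device) -> float:
        # Sum the per-rank scalar across all workers, then divide by
        # the number of workers to obtain the global mean.
        t = torch.tensor(value, dtype=torch.float32, device=device)
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
        return t.item() / dist.get_world_size()

iterator_stop, which this patch now re-creates after the closing dist.barrier(),
is a stop sentinel for ranks whose shard of the data runs out first: every rank
all-reduces the flag once per batch, so as soon as any rank raises it, all ranks
leave the loop in the same iteration instead of blocking in a collective that
some rank never reaches. One deadlock-free way to drive such a sentinel (a
sketch of the idea, not the trainer's exact loop; dataloader, device, and
train_step are placeholders):

    it = iter(dataloader)
    while True:
        flag = torch.tensor(0, device=device)
        try:
            batch = next(it)
        except StopIteration:
            flag += 1  # this rank is out of data
        # Every rank joins the reduce on every step, exhausted or not,
        # so the collective calls stay aligned across the group.
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        if flag > 0:
            break
        train_step(batch)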
 funasr/train_utils/trainer.py |   50 +++++++++++++++++++++++++-------------------------
 1 file changed, 25 insertions(+), 25 deletions(-)

diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index d0023fd..2d47fc1 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -239,6 +239,8 @@
         Args:
             epoch (int): The current epoch number.
         """
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
         logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
         model.train()
 
@@ -249,8 +251,7 @@
         speed_stats = {}
         time5 = time.perf_counter()
         iterator_stop = torch.tensor(0).to(self.device)
-        dist.barrier()
-        print(f"before iter, iterator_stop: {iterator_stop}\n")
+
         dataloader_train.batch_sampler.set_epoch(epoch)
         for batch_idx, batch in enumerate(dataloader_train):
             if self.use_ddp or self.use_fsdp:
@@ -297,13 +298,13 @@
                 self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
                 if "acc" in stats:
                     self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
-                # if self.use_ddp or self.use_fsdp:
-                #     train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
-                #     train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
-                #     dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
-                #     dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
-                #     self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
-                #     self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+                if self.use_ddp or self.use_fsdp:
+                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
+                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
+                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
                 
             
             # Perform an optimizer step only after accumulating enough gradients
@@ -339,7 +340,7 @@
     
                 speed_stats["total_time"] = total_time
                 lr = scheduler.get_last_lr()[0]
-                batch_num_epoch = -1
+                batch_num_epoch = 1
                 if hasattr(dataloader_train, "__len__"):
                     batch_num_epoch = len(dataloader_train)
                 self.log(epoch, batch_idx,
@@ -370,6 +371,7 @@
                 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
+            iterator_stop = torch.tensor(0).to(self.device)
         
         
 
@@ -387,6 +389,8 @@
         Args:
             epoch (int): The current epoch number.
         """
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
         logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
         model.eval()
         
@@ -395,13 +399,10 @@
             speed_stats = {}
             time5 = time.perf_counter()
             iterator_stop = torch.tensor(0).to(self.device)
-            dist.barrier()
-            print(f"before iter, iterator_stop: {iterator_stop}\n")
+            dataloader_val.batch_sampler.set_epoch(epoch)
             for batch_idx, batch in enumerate(dataloader_val):
                 if self.use_ddp or self.use_fsdp:
                     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
-                    if epoch >= 1:
-                        print(f"iterator_stop: {iterator_stop}\n")
                     if iterator_stop > 0:
                         break
                 time1 = time.perf_counter()
@@ -417,7 +418,7 @@
                     # Apply weighted averaging for loss and stats
                     loss = (loss * weight.type(loss.dtype)).sum()
                     # if distributed, this method can also apply all_reduce()
-                    stats, weight = recursive_average(stats, weight, distributed=True)
+                    # stats, weight = recursive_average(stats, weight, distributed=True)
                     if self.use_ddp or self.use_fsdp:
                         dist.all_reduce(weight, op=dist.ReduceOp.SUM)
                     # Now weight is summation over all workers
@@ -432,15 +433,15 @@
                 self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
                 if "acc" in stats:
                     self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
-                # if self.use_ddp or self.use_fsdp:
-                #     val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
-                #     val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
-                #     dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
-                #     dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
-                #     self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
-                #     self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+                if self.use_ddp or self.use_fsdp:
+                    val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
+                    val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
+                    dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+                    self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+                    self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
                 
-                batch_num_epoch = -1
+                batch_num_epoch = 1
                 if hasattr(dataloader_val, "__len__"):
                     batch_num_epoch = len(dataloader_val)
                 self.log(epoch, batch_idx,
@@ -461,10 +462,9 @@
         self.val_acc_list.append(self.val_acc_avg)
         model.train()
 
-
-
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
+            iterator_stop = torch.tensor(0).to(self.device)
         
         
     def log(self,

--
Gitblit v1.9.1