From 5d74e107fc5696b70e75003c278f8babd17161e8 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 24 三月 2024 00:58:56 +0800
Subject: [PATCH] update

---
 funasr/train_utils/trainer.py |   65 +++++++++++++++++++++++---------
 1 files changed, 47 insertions(+), 18 deletions(-)

diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index c443c6f..23c18d9 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -198,6 +198,8 @@
                 for k in dst_state.keys():
                     if not k.startswith("module.") and "module."+k in src_state.keys():
                         k_ddp = "module."+k
+                    elif k.startswith("module.") and "module."+k not in src_state.keys():
+                        k_ddp = k.replace("module.", "", 1)
                     else:
                         k_ddp = k
                     if k_ddp in src_state.keys():
@@ -246,8 +248,15 @@
         optim.zero_grad()
         speed_stats = {}
         time5 = time.perf_counter()
-        
+        iterator_stop = torch.tensor(0).to(self.device)
+        dist.barrier()
+        print(f"before iter, iterator_stop: {iterator_stop}\n")
+        dataloader_train.batch_sampler.set_epoch(epoch)
         for batch_idx, batch in enumerate(dataloader_train):
+            if self.use_ddp or self.use_fsdp:
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
             self.batch_total += 1
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1-time5:0.3f}"
@@ -288,13 +297,13 @@
                 self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
                 if "acc" in stats:
                     self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
-                if self.use_ddp or self.use_fsdp:
-                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
-                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
-                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
-                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
-                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
-                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+                # if self.use_ddp or self.use_fsdp:
+                #     train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
+                #     train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
+                #     dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+                #     dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+                #     self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+                #     self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
                 
             
             # Perform an optimizer step only after accumulating enough gradients
@@ -354,7 +363,11 @@
             if (batch_idx+1) % self.save_checkpoint_interval == 0:
                 self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
 
-        
+        else:
+            if self.use_ddp or self.use_fsdp:
+                iterator_stop.fill_(1)
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
         
@@ -381,7 +394,16 @@
             
             speed_stats = {}
             time5 = time.perf_counter()
+            iterator_stop = torch.tensor(0).to(self.device)
+            dist.barrier()
+            print(f"before iter, iterator_stop: {iterator_stop}\n")
             for batch_idx, batch in enumerate(dataloader_val):
+                if self.use_ddp or self.use_fsdp:
+                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                    if epoch >= 1:
+                        print(f"iterator_stop: {iterator_stop}\n")
+                    if iterator_stop > 0:
+                        break
                 time1 = time.perf_counter()
                 speed_stats["data_load"] = f"{time1 - time5:0.3f}"
                 batch = to_device(batch, self.device)
@@ -395,7 +417,7 @@
                     # Apply weighted averaging for loss and stats
                     loss = (loss * weight.type(loss.dtype)).sum()
                     # if distributed, this method can also apply all_reduce()
-                    stats, weight = recursive_average(stats, weight, distributed=True)
+                    # stats, weight = recursive_average(stats, weight, distributed=True)
                     if self.use_ddp or self.use_fsdp:
                         dist.all_reduce(weight, op=dist.ReduceOp.SUM)
                     # Now weight is summation over all workers
@@ -410,13 +432,13 @@
                 self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
                 if "acc" in stats:
                     self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
-                if self.use_ddp or self.use_fsdp:
-                    val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
-                    val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
-                    dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
-                    dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
-                    self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
-                    self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+                # if self.use_ddp or self.use_fsdp:
+                #     val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
+                #     val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
+                #     dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+                #     dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+                #     self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+                #     self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
                 
                 batch_num_epoch = -1
                 if hasattr(dataloader_val, "__len__"):
@@ -431,9 +453,16 @@
                          tag="val",
                          )
 
+            else:
+                if self.use_ddp or self.use_fsdp:
+                    iterator_stop.fill_(1)
+                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                    
         self.val_acc_list.append(self.val_acc_avg)
         model.train()
-        
+
+
+
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
         

--
Gitblit v1.9.1