From 873cfae5c347b940e38e853d8579a6b4e85ada05 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 24 三月 2024 00:45:45 +0800
Subject: [PATCH] update

---
 funasr/train_utils/trainer.py |   33 ++++++++++++++++++++++++++++++---
 1 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 3e83581..d0023fd 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -248,8 +248,15 @@
         optim.zero_grad()
         speed_stats = {}
         time5 = time.perf_counter()
-        
+        iterator_stop = torch.tensor(0).to(self.device)
+        dist.barrier()
+        print(f"before iter, iterator_stop: {iterator_stop}\n")
+        dataloader_train.batch_sampler.set_epoch(epoch)
         for batch_idx, batch in enumerate(dataloader_train):
+            if self.use_ddp or self.use_fsdp:
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
             self.batch_total += 1
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1-time5:0.3f}"
@@ -356,7 +363,11 @@
             if (batch_idx+1) % self.save_checkpoint_interval == 0:
                 self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
 
-        
+        else:
+            if self.use_ddp or self.use_fsdp:
+                iterator_stop.fill_(1)
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
         
@@ -383,7 +394,16 @@
             
             speed_stats = {}
             time5 = time.perf_counter()
+            iterator_stop = torch.tensor(0).to(self.device)
+            dist.barrier()
+            print(f"before iter, iterator_stop: {iterator_stop}\n")
             for batch_idx, batch in enumerate(dataloader_val):
+                if self.use_ddp or self.use_fsdp:
+                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                    if epoch >= 1:
+                        print(f"iterator_stop: {iterator_stop}\n")
+                    if iterator_stop > 0:
+                        break
                 time1 = time.perf_counter()
                 speed_stats["data_load"] = f"{time1 - time5:0.3f}"
                 batch = to_device(batch, self.device)
@@ -433,9 +453,16 @@
                          tag="val",
                          )
 
+            else:
+                if self.use_ddp or self.use_fsdp:
+                    iterator_stop.fill_(1)
+                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                    
         self.val_acc_list.append(self.val_acc_avg)
         model.train()
-        
+
+
+
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
         

--
Gitblit v1.9.1