From 507f821d7ab0a51a4f01b8557fe38e0bcf0d14f6 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 26 Mar 2024 16:45:10 +0800
Subject: [PATCH] update

---
 funasr/train_utils/trainer.py |   21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index c665394..56ec604 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -190,7 +190,7 @@
         if self.resume:
             ckpt = os.path.join(self.output_dir, "model.pt")
             if os.path.isfile(ckpt):
-                checkpoint = torch.load(ckpt)
+                checkpoint = torch.load(ckpt, map_location="cpu")
                 self.start_epoch = checkpoint['epoch'] + 1
                 # self.model.load_state_dict(checkpoint['state_dict'])
                 src_state = checkpoint['state_dict']
@@ -215,7 +215,7 @@
                 
                 self.val_acc_list = checkpoint["acc"]
                 self.step_or_epoch = checkpoint["step_or_epoch"]
-                
+                model.to(self.device)
                 print(f"Checkpoint loaded successfully from '{ckpt}'")
             else:
                 print(f"No checkpoint found at '{ckpt}', does not resume status!")
@@ -371,8 +371,7 @@
                 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-
-        iterator_stop = torch.tensor(0).to(self.device)
+            iterator_stop = torch.tensor(0).to(self.device)
         
         
 
@@ -402,12 +401,10 @@
             iterator_stop = torch.tensor(0).to(self.device)
             dataloader_val.batch_sampler.set_epoch(epoch)
             for batch_idx, batch in enumerate(dataloader_val):
-                # if self.use_ddp or self.use_fsdp:
-                #     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
-                #     if epoch >= 1:
-                #         print(f"iterator_stop: {iterator_stop}\n")
-                #     if iterator_stop > 0:
-                #         break
+                if self.use_ddp or self.use_fsdp:
+                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                    if iterator_stop > 0:
+                        break
                 time1 = time.perf_counter()
                 speed_stats["data_load"] = f"{time1 - time5:0.3f}"
                 batch = to_device(batch, self.device)
@@ -443,7 +440,7 @@
                     dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
                     self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
                     self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
-                
+                time5 = time.perf_counter()
                 batch_num_epoch = 1
                 if hasattr(dataloader_val, "__len__"):
                     batch_num_epoch = len(dataloader_val)
@@ -467,7 +464,7 @@
 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        iterator_stop = torch.tensor(0).to(self.device)
+            iterator_stop = torch.tensor(0).to(self.device)
         
         
     def log(self,

--
Gitblit v1.9.1