游雁
2024-03-24 ed952ff630a1ad89df488d3c657ef736129e6c10
funasr/train_utils/trainer.py
@@ -371,8 +371,7 @@
                
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        iterator_stop = torch.tensor(0).to(self.device)
            iterator_stop = torch.tensor(0).to(self.device)
        
        
@@ -402,12 +401,10 @@
            iterator_stop = torch.tensor(0).to(self.device)
            dataloader_val.batch_sampler.set_epoch(epoch)
            for batch_idx, batch in enumerate(dataloader_val):
                # if self.use_ddp or self.use_fsdp:
                #     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
                #     if epoch >= 1:
                #         print(f"iterator_stop: {iterator_stop}\n")
                #     if iterator_stop > 0:
                #         break
                if self.use_ddp or self.use_fsdp:
                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
                    if iterator_stop > 0:
                        break
                time1 = time.perf_counter()
                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
                batch = to_device(batch, self.device)
@@ -467,7 +464,7 @@
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        iterator_stop = torch.tensor(0).to(self.device)
            iterator_stop = torch.tensor(0).to(self.device)
        
        
    def log(self,