游雁
2024-03-24 873cfae5c347b940e38e853d8579a6b4e85ada05
funasr/train_utils/trainer.py
@@ -249,6 +249,9 @@
        speed_stats = {}
        time5 = time.perf_counter()
        iterator_stop = torch.tensor(0).to(self.device)
        dist.barrier()
        print(f"before iter, iterator_stop: {iterator_stop}\n")
        dataloader_train.batch_sampler.set_epoch(epoch)
        for batch_idx, batch in enumerate(dataloader_train):
            if self.use_ddp or self.use_fsdp:
                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
@@ -392,9 +395,13 @@
            speed_stats = {}
            time5 = time.perf_counter()
            iterator_stop = torch.tensor(0).to(self.device)
            dist.barrier()
            print(f"before iter, iterator_stop: {iterator_stop}\n")
            for batch_idx, batch in enumerate(dataloader_val):
                if self.use_ddp or self.use_fsdp:
                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
                    if epoch >= 1:
                        print(f"iterator_stop: {iterator_stop}\n")
                    if iterator_stop > 0:
                        break
                time1 = time.perf_counter()