游雁
2024-02-19 94de39dde2e616a01683c518023d0fab72b4e103
funasr/train_utils/trainer.py
@@ -147,9 +147,18 @@
        for epoch in range(self.start_epoch, self.max_epoch + 1):
            
            self._train_epoch(epoch)
            
            if self.use_ddp or self.use_fsdp:
                dist.barrier()
            self._validate_epoch(epoch)
            if self.use_ddp or self.use_fsdp:
                dist.barrier()
            if self.rank == 0:
                self._save_checkpoint(epoch)
            
@@ -164,7 +173,10 @@
            
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        self.writer.close()
        if self.writer:
            self.writer.close()
        
    
    def _train_epoch(self, epoch):
@@ -192,7 +204,25 @@
            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            with my_context():
                time2 = time.perf_counter()
                # print("before, GPU, memory: {:.3f} GB, "
                #       "{:.3f} GB, "
                #       "{:.3f} GB, "
                #       "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
                #                      torch.cuda.max_memory_allocated()/1024/1024/1024,
                #                      torch.cuda.memory_reserved()/1024/1024/1024,
                #                      torch.cuda.max_memory_reserved()/1024/1024/1024,
                #                      ))
                retval = self.model(**batch)
                torch.cuda.empty_cache()
                # print("after, GPU, memory: {:.3f} GB, "
                #       "{:.3f} GB, "
                #       "{:.3f} GB, "
                #       "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
                #                      torch.cuda.max_memory_allocated()/1024/1024/1024,
                #                      torch.cuda.memory_reserved()/1024/1024/1024,
                #                      torch.cuda.max_memory_reserved()/1024/1024/1024,
                #                      ))
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
@@ -230,6 +260,8 @@
                        continue
                
                # Execute an optimization step (update model parameters)
                if self.use_ddp or self.use_fsdp:
                    dist.barrier()
                self.optim.step()
                self.scheduler.step()
                # Clear gradients for the next accumulation stage
@@ -243,12 +275,21 @@
            pbar.update(1)
            if self.local_rank == 0:
                gpu_info = "GPU, memory: {:.3f} GB, " \
                           "{:.3f} GB, "\
                           "{:.3f} GB, "\
                           "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
                                             torch.cuda.max_memory_allocated()/1024/1024/1024,
                                             torch.cuda.memory_reserved()/1024/1024/1024,
                                             torch.cuda.max_memory_reserved()/1024/1024/1024,
                                             )
                description = (
                    f"Epoch: {epoch}/{self.max_epoch}, "
                    f"Train epoch: {epoch}/{self.max_epoch}, "
                    f"step {batch_idx}/{len(self.dataloader_train)}, "
                    f"{speed_stats}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "
                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
                    f"{gpu_info}"
                )
                pbar.set_description(description)
                if self.writer:
@@ -306,7 +347,7 @@
                pbar.update(1)
                if self.local_rank == 0:
                    description = (
                        f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
                        f"validation epoch: {epoch}/{self.max_epoch}, "
                        f"step {batch_idx}/{len(self.dataloader_train)}, "
                        f"{speed_stats}, "
                        f"(loss: {loss.detach().cpu().item():.3f}), "