游雁
2024-03-24 ed22e34d654c47017962d3e5758d3a351d8826ab
finetune
2 files changed, 43 lines changed:
funasr/bin/train.py (8)
funasr/train_utils/trainer.py (35)
funasr/bin/train.py
@@ -173,11 +173,11 @@
     except:
         writer = None
-    if use_ddp or use_fsdp:
-        context = Join([model])
-    else:
-        context = nullcontext()
+    # if use_ddp or use_fsdp:
+    #     context = Join([model])
+    # else:
+    #     context = nullcontext()
+    context = nullcontext()
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
         with context:
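Note: with the manual stop-flag handshake re-enabled in trainer.py below, the Join-based handling of uneven DDP inputs becomes redundant, so this hunk disables it and the epoch loop always runs under nullcontext. For reference, a minimal sketch of what the disabled branch did (standard torch.distributed.algorithms.join.Join usage; the helper name make_epoch_context is illustrative, not FunASR API):

    from contextlib import nullcontext

    from torch.distributed.algorithms.join import Join

    def make_epoch_context(model, use_ddp: bool, use_fsdp: bool):
        # Join keeps ranks whose dataloader runs out early answering the
        # collectives issued by ranks still training, so DDP does not
        # deadlock when data shards have uneven length.
        if use_ddp or use_fsdp:
            return Join([model])  # model must be a Joinable, e.g. a DDP wrapper
        return nullcontext()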
funasr/train_utils/trainer.py
@@ -250,14 +250,14 @@
         optim.zero_grad()
         speed_stats = {}
         time5 = time.perf_counter()
-        # iterator_stop = torch.tensor(0).to(self.device)
+        iterator_stop = torch.tensor(0).to(self.device)
         dataloader_train.batch_sampler.set_epoch(epoch)
         for batch_idx, batch in enumerate(dataloader_train):
-            # if self.use_ddp or self.use_fsdp:
-            #     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
-            #     if iterator_stop > 0:
-            #         break
+            if self.use_ddp or self.use_fsdp:
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
             self.batch_total += 1
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1-time5:0.3f}"
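Note: the uncommented block is the receiving half of a manual early-stop handshake for uneven data shards. Every rank all_reduces its (initially zero) iterator_stop flag once per step, so the sum stays 0 until some rank has exhausted its dataloader and raised the flag, at which point all ranks break out of the loop together instead of hanging in DDP's gradient synchronization. The sending half is the for/else block re-enabled below; a self-contained sketch of the whole pattern follows that hunk.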
@@ -340,7 +340,7 @@
 
                 speed_stats["total_time"] = total_time
                 lr = scheduler.get_last_lr()[0]
-                batch_num_epoch = -1
+                batch_num_epoch = 1
                 if hasattr(dataloader_train, "__len__"):
                     batch_num_epoch = len(dataloader_train)
                 self.log(epoch, batch_idx,
@@ -364,13 +364,15 @@
             if (batch_idx+1) % self.save_checkpoint_interval == 0:
                 self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
-        # else:
-        #     if self.use_ddp or self.use_fsdp:
-        #         iterator_stop.fill_(1)
-        #         dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+        else:
+            if self.use_ddp or self.use_fsdp:
+                iterator_stop.fill_(1)
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
+        iterator_stop = torch.tensor(0).to(self.device)
 
 
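Note: together with the loop-top check above, this for/else completes the handshake. A rank that drains its shard without breaking lands in the else, sets its flag to 1, and joins one final all_reduce, which is matched by the other ranks' next loop-top all_reduce. A minimal sketch of the protocol, assuming a plain DDP setup (run_epoch and train_step are illustrative names, not FunASR functions):

    import torch
    import torch.distributed as dist

    def run_epoch(dataloader, device, train_step, use_ddp: bool = True):
        # Shared stop flag: stays 0 while every rank still has batches.
        iterator_stop = torch.tensor(0).to(device)
        for batch in dataloader:
            if use_ddp:
                # Matched collective each step: the sum turns positive one
                # step after any rank runs out of data, so all ranks break
                # together instead of deadlocking in gradient all_reduce.
                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
                if iterator_stop > 0:
                    break
            train_step(batch)
        else:
            # No break: this rank finished its shard first; raise the flag
            # so the other ranks' next loop-top all_reduce sees a sum > 0.
            if use_ddp:
                iterator_stop.fill_(1)
                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
        if use_ddp:
            dist.barrier()  # realign ranks before the flag is reused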
@@ -397,7 +399,7 @@
 
             speed_stats = {}
             time5 = time.perf_counter()
-            # iterator_stop = torch.tensor(0).to(self.device)
+            iterator_stop = torch.tensor(0).to(self.device)
            dataloader_val.batch_sampler.set_epoch(epoch)
            for batch_idx, batch in enumerate(dataloader_val):
                # if self.use_ddp or self.use_fsdp:
@@ -442,7 +444,7 @@
                     self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
                     self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
 
-                batch_num_epoch = -1
+                batch_num_epoch = 1
                 if hasattr(dataloader_val, "__len__"):
                     batch_num_epoch = len(dataloader_val)
                 self.log(epoch, batch_idx,
@@ -455,16 +457,17 @@
                          tag="val",
                          )
-            # else:
-            #     if self.use_ddp or self.use_fsdp:
-            #         iterator_stop.fill_(1)
-            #         dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+            else:
+                if self.use_ddp or self.use_fsdp:
+                    iterator_stop.fill_(1)
+                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
 
         self.val_acc_list.append(self.val_acc_avg)
         model.train()
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
+        iterator_stop = torch.tensor(0).to(self.device)
 
 
     def log(self,