游雁
2024-03-24 a70f5b3edf22ac889724aa9a06cefbb316374b28
finetune
2个文件已修改
23 ■■■■ 已修改文件
funasr/bin/train.py 21 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py
@@ -173,10 +173,10 @@
    except:
        writer = None
-    # if use_ddp or use_fsdp:
-    #     context = Join([model])
-    # else:
-    context = nullcontext()
+    if use_ddp or use_fsdp:
+        context = Join([model])
+    else:
+        context = nullcontext()
    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
        time1 = time.perf_counter()
@@ -192,13 +192,14 @@
                                epoch=epoch,
                                writer=writer
                                )
+        with context:
+            trainer.validate_epoch(
+                model=model,
+                dataloader_val=dataloader_val,
+                epoch=epoch,
+                writer=writer
+            )
         scheduler.step()
-        trainer.validate_epoch(
-            model=model,
-            dataloader_val=dataloader_val,
-            epoch=epoch,
-            writer=writer
-        )
        
        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
funasr/train_utils/trainer.py
@@ -398,7 +398,7 @@
            speed_stats = {}
            time5 = time.perf_counter()
            # iterator_stop = torch.tensor(0).to(self.device)
            dataloader_val.batch_sampler.set_epoch(epoch)
            for batch_idx, batch in enumerate(dataloader_val):
                # if self.use_ddp or self.use_fsdp:
                #     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)