游雁
2024-03-27 9b4e9cc8a0311e5243d69b73ed073e7ea441982e
funasr/bin/train.py
@@ -150,8 +150,8 @@
    # dataset
    logging.info("Build dataloader")
    dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
    # dataloader = dataloader_class(**kwargs)
    dataloader_tr, dataloader_val = dataloader_class(**kwargs)
    dataloader = dataloader_class(**kwargs)
    # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
    trainer = Trainer(local_rank=local_rank,
                      use_ddp=use_ddp,
                      use_fsdp=use_fsdp,
@@ -173,15 +173,15 @@
    except:
        writer = None
    if use_ddp or use_fsdp:
        context = Join([model])
    else:
        context = nullcontext()
    # if use_ddp or use_fsdp:
    #     context = Join([model])
    # else:
    #     context = nullcontext()
    context = nullcontext()
    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
        time1 = time.perf_counter()
        with context:
            # dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
            dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
            trainer.train_epoch(
                                model=model,
                                optim=optim,
@@ -214,7 +214,7 @@
    if trainer.rank == 0:
        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)
        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
    trainer.close()