| | |
| | | # dataset |
| | | logging.info("Build dataloader") |
| | | dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle")) |
| | | # dataloader = dataloader_class(**kwargs) |
| | | dataloader_tr, dataloader_val = dataloader_class(**kwargs) |
| | | dataloader = dataloader_class(**kwargs) |
| | | # dataloader_tr, dataloader_val = dataloader_class(**kwargs) |
| | | trainer = Trainer(local_rank=local_rank, |
| | | use_ddp=use_ddp, |
| | | use_fsdp=use_fsdp, |
| | |
| | | for epoch in range(trainer.start_epoch, trainer.max_epoch + 1): |
| | | time1 = time.perf_counter() |
| | | with context: |
| | | # dataloader_tr, dataloader_val = dataloader.build_iter(epoch) |
| | | dataloader_tr, dataloader_val = dataloader.build_iter(epoch) |
| | | trainer.train_epoch( |
| | | model=model, |
| | | optim=optim, |
| | |
| | | |
| | | from funasr.register import tables |
| | | |
| | | @tables.register("dataloader_classes", "DataloaderMapStyle") |
| | | # @tables.register("dataloader_classes", "DataloaderMapStyle") |
| | | def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs): |
| | | # dataset |
| | | logging.info("Build dataloader") |
| | |
| | | |
| | | return dataloader_tr, dataloader_val |
| | | |
| | | # @tables.register("dataloader_classes", "DataloaderMapStyle") |
| | | @tables.register("dataloader_classes", "DataloaderMapStyle") |
| | | class DataloaderMapStyle: |
| | | def __init__(self, frontend=None, tokenizer=None, **kwargs): |
| | | # dataset |
| | |
| | | |
| | | if self.use_ddp or self.use_fsdp: |
| | | dist.barrier() |
| | | |
| | | iterator_stop = torch.tensor(0).to(self.device) |
| | | iterator_stop = torch.tensor(0).to(self.device) |
| | | |
| | | |
| | | |
| | |
| | | iterator_stop = torch.tensor(0).to(self.device) |
| | | dataloader_val.batch_sampler.set_epoch(epoch) |
| | | for batch_idx, batch in enumerate(dataloader_val): |
| | | # if self.use_ddp or self.use_fsdp: |
| | | # dist.all_reduce(iterator_stop, dist.ReduceOp.SUM) |
| | | # if epoch >= 1: |
| | | # print(f"iterator_stop: {iterator_stop}\n") |
| | | # if iterator_stop > 0: |
| | | # break |
| | | if self.use_ddp or self.use_fsdp: |
| | | dist.all_reduce(iterator_stop, dist.ReduceOp.SUM) |
| | | if iterator_stop > 0: |
| | | break |
| | | time1 = time.perf_counter() |
| | | speed_stats["data_load"] = f"{time1 - time5:0.3f}" |
| | | batch = to_device(batch, self.device) |
| | |
| | | |
| | | if self.use_ddp or self.use_fsdp: |
| | | dist.barrier() |
| | | iterator_stop = torch.tensor(0).to(self.device) |
| | | iterator_stop = torch.tensor(0).to(self.device) |
| | | |
| | | |
| | | def log(self, |