| | |
| | | scheduler_class = scheduler_classes.get(scheduler) |
| | | scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf")) |
| | | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | |
| | | # dataset |
| | | dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset")) |
| | | dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf")) |
| | | dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, |
| | | **kwargs.get("dataset_conf")) |
| | | |
| | | # dataloader |
| | | batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler") |
| | | batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler) |
| | | batch_sampler_val = None |
| | | if batch_sampler is not None: |
| | | batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler) |
| | | batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf")) |
| | | batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf")) |
| | | dataloader_tr = torch.utils.data.DataLoader(dataset_tr, |
| | | collate_fn=dataset_tr.collator, |
| | | batch_sampler=batch_sampler, |
| | | num_workers=kwargs.get("dataset_conf").get("num_workers", 4), |
| | | pin_memory=True) |
| | | |
| | | |
| | | dataloader_val = torch.utils.data.DataLoader(dataset_val, |
| | | collate_fn=dataset_val.collator, |
| | | batch_sampler=batch_sampler_val, |
| | | num_workers=kwargs.get("dataset_conf").get("num_workers", 4), |
| | | pin_memory=True) |
| | | trainer = Trainer( |
| | | model=model, |
| | | optim=optim, |
| | | scheduler=scheduler, |
| | | dataloader_train=dataloader_tr, |
| | | dataloader_val=None, |
| | | dataloader_val=dataloader_val, |
| | | local_rank=local_rank, |
| | | use_ddp=use_ddp, |
| | | use_fsdp=use_fsdp, |