| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | |
| | | import os |
| | | import sys |
| | | import torch |
| | |
| | | |
| | | |
def main(**kwargs):
    """Assemble the training components from ``kwargs`` and run the trainer.

    Steps, in order: print the config and the registry tables, seed the
    random number generators, build the frontend and the model, build the
    learning-rate scheduler, build the train/validation datasets and their
    dataloaders, then hand everything to ``Trainer`` and call ``run()``.

    Keys read from ``kwargs``:
        seed: random seed, defaults to 0.
        frontend_conf: keyword arguments for the frontend class (required).
        model: registry key of the model class (required).
        model_conf: extra keyword arguments for the model class (required).
        scheduler_conf: keyword arguments for the scheduler class; a
            missing or ``None`` value is treated as empty.
        dataset: registry key of the dataset class, defaults to
            ``"AudioDataset"``.
        train_data_set_list, valid_data_set_list: first positional argument
            of the train / validation dataset.
        dataset_conf: mapping (required) forwarded to the dataset and batch
            sampler classes; its ``batch_sampler`` entry (default
            ``"DynamicBatchLocalShuffleSampler"``) and ``num_workers`` entry
            (default 4) are also read here.
        output_dir: defaults to ``"./exp"``.
        resume: defaults to ``True``.
        train_conf: extra keyword arguments for ``Trainer``; a missing or
            ``None`` value is treated as empty.

    Side effects: ``kwargs["frontend"]`` and ``kwargs["input_size"]`` are
    set before the model is built, so the model class receives both.

    NOTE(review): ``tables``, ``set_all_random_seed``, ``frontend_class``,
    ``tokenizer``, ``scheduler_classes``, ``scheduler``, ``optim``,
    ``local_rank``, ``use_ddp``, ``use_fsdp`` and ``Trainer`` are not bound
    anywhere in this function; they are presumably provided by code outside
    it -- confirm before relying on this entry point.
    """
    print(kwargs)
    tables.print()
    # set random seed
    set_all_random_seed(kwargs.get("seed", 0))

    # build frontend and publish it (plus its output size) through kwargs,
    # which is forwarded wholesale to the model class below
    frontend = frontend_class(**kwargs["frontend_conf"])
    kwargs["frontend"] = frontend
    kwargs["input_size"] = frontend.output_size()

    # build model
    model_class = tables.model_classes.get(kwargs["model"])
    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))

    # scheduler
    # The instance gets its own name: rebinding `scheduler` here would make
    # it a local of this function, so the lookup below would raise
    # UnboundLocalError before any value could be read.
    scheduler_class = scheduler_classes.get(scheduler)
    lr_scheduler = scheduler_class(optim, **(kwargs.get("scheduler_conf") or {}))

    # dataset
    dataset_conf = kwargs["dataset_conf"]
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_tr = dataset_class(
        kwargs.get("train_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=True,
        **dataset_conf,
    )
    dataset_val = dataset_class(
        kwargs.get("valid_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=False,
        **dataset_conf,
    )

    # dataloader
    batch_sampler = dataset_conf.get("batch_sampler", "DynamicBatchLocalShuffleSampler")
    batch_sampler_val = None
    if batch_sampler is not None:
        # an explicit `batch_sampler: None` in dataset_conf skips this branch
        # and leaves both DataLoaders without a batch sampler
        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
        batch_sampler = batch_sampler_class(dataset_tr, **dataset_conf)
        batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **dataset_conf)

    num_workers = dataset_conf.get("num_workers", 4)
    dataloader_tr = torch.utils.data.DataLoader(
        dataset_tr,
        collate_fn=dataset_tr.collator,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=True,
    )
    dataloader_val = torch.utils.data.DataLoader(
        dataset_val,
        collate_fn=dataset_val.collator,
        batch_sampler=batch_sampler_val,
        num_workers=num_workers,
        pin_memory=True,
    )

    # `dataloader_val` is passed exactly once; repeating the keyword in a
    # call is a SyntaxError
    trainer = Trainer(
        model=model,
        optim=optim,
        scheduler=lr_scheduler,
        dataloader_train=dataloader_tr,
        dataloader_val=dataloader_val,
        local_rank=local_rank,
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
        output_dir=kwargs.get("output_dir", "./exp"),
        resume=kwargs.get("resume", True),
        **(kwargs.get("train_conf") or {}),
    )
    trainer.run()