| | |
| | | from torch.nn.parallel import DistributedDataParallel as DDP |
| | | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| | | from funasr.download.download_from_hub import download_model |
| | | from funasr.utils.register import registry_tables |
| | | from funasr.register import tables |
| | | |
| | | @hydra.main(config_name=None, version_base=None) |
| | | @hydra.main(config_name=None) |
| | | def main_hydra(kwargs: DictConfig): |
| | | import pdb; pdb.set_trace() |
| | | if kwargs.get("debug", False): |
| | | import pdb; pdb.set_trace() |
| | | |
| | | assert "model" in kwargs |
| | | if "model_conf" not in kwargs: |
| | | logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms"))) |
| | |
| | | # preprocess_config(kwargs) |
| | | # import pdb; pdb.set_trace() |
| | | # set random seed |
| | | registry_tables.print() |
| | | tables.print() |
| | | set_all_random_seed(kwargs.get("seed", 0)) |
| | | torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled) |
| | | torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark) |
| | |
| | | |
| | | tokenizer = kwargs.get("tokenizer", None) |
| | | if tokenizer is not None: |
| | | tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower()) |
| | | tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower()) |
| | | tokenizer = tokenizer_class(**kwargs["tokenizer_conf"]) |
| | | kwargs["tokenizer"] = tokenizer |
| | | |
| | | # build frontend if frontend is not None |
| | | frontend = kwargs.get("frontend", None) |
| | | if frontend is not None: |
| | | frontend_class = registry_tables.frontend_classes.get(frontend.lower()) |
| | | frontend_class = tables.frontend_classes.get(frontend.lower()) |
| | | frontend = frontend_class(**kwargs["frontend_conf"]) |
| | | kwargs["frontend"] = frontend |
| | | kwargs["input_size"] = frontend.output_size() |
| | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | # build model |
| | | model_class = registry_tables.model_classes.get(kwargs["model"].lower()) |
| | | model_class = tables.model_classes.get(kwargs["model"].lower()) |
| | | model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list)) |
| | | |
| | | |
| | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | # dataset |
| | | dataset_class = registry_tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset").lower()) |
| | | dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset").lower()) |
| | | dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf")) |
| | | |
| | | # dataloader |
| | | batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler") |
| | | batch_sampler_class = registry_tables.batch_sampler_classes.get(batch_sampler.lower()) |
| | | batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf")) |
| | | batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler.lower()) |
| | | if batch_sampler is not None: |
| | | batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf")) |
| | | dataloader_tr = torch.utils.data.DataLoader(dataset_tr, |
| | | collate_fn=dataset_tr.collator, |
| | | batch_sampler=batch_sampler, |
| | |
| | | pin_memory=True) |
| | | |
| | | |
| | | |
| | | trainer = Trainer( |
| | | model=model, |
| | | optim=optim, |