| | |
| | | data_path_and_name_and_type, |
| | | preprocess=preprocess_fn, |
| | | dest_sample_rate=dest_sample_rate, |
| | | speed_perturb=args.speed_perturb if mode=="train" else None, |
| | | speed_perturb=args.speed_perturb if mode == "train" else None, |
| | | ) |
| | | |
| | | # sampler |
| | |
| | | args.max_update = len(bs_list) * args.max_epoch |
| | | logging.info("Max update: {}".format(args.max_update)) |
| | | |
| | | if args.distributed and mode=="train": |
| | | if args.distributed and mode == "train": |
| | | world_size = torch.distributed.get_world_size() |
| | | rank = torch.distributed.get_rank() |
| | | for batch in batches: |