funasr/datasets/small_datasets/sequence_iter_factory.py
@@ -57,7 +57,7 @@ data_path_and_name_and_type, preprocess=preprocess_fn, dest_sample_rate=dest_sample_rate, speed_perturb=args.speed_perturb, speed_perturb=args.speed_perturb if mode == "train" else None, ) # sampler @@ -84,7 +84,7 @@ args.max_update = len(bs_list) * args.max_epoch logging.info("Max update: {}".format(args.max_update)) if args.distributed and mode=="train": if args.distributed and mode == "train": world_size = torch.distributed.get_world_size() rank = torch.distributed.get_rank() for batch in batches: