funasr/datasets/small_datasets/sequence_iter_factory.py
@@ -57,6 +57,7 @@
         data_path_and_name_and_type,
         preprocess=preprocess_fn,
         dest_sample_rate=dest_sample_rate,
+        speed_perturb=args.speed_perturb if mode=="train" else None,
     )

     # sampler
@@ -83,7 +84,7 @@
         args.max_update = len(bs_list) * args.max_epoch
         logging.info("Max update: {}".format(args.max_update))

-    if args.distributed:
+    if args.distributed and mode=="train":
         world_size = torch.distributed.get_world_size()
         rank = torch.distributed.get_rank()
         for batch in batches: