funasr/datasets/small_datasets/sequence_iter_factory.py
@@ -83,7 +83,7 @@
         args.max_update = len(bs_list) * args.max_epoch
         logging.info("Max update: {}".format(args.max_update))
-        if args.distributed:
+        if args.distributed and mode == "train":
             world_size = torch.distributed.get_world_size()
             rank = torch.distributed.get_rank()
             for batch in batches: