九耳
2023-02-28 ee06cb9c6870d9e1579015aabfe1a84a61a5c087
funasr/tasks/abs_task.py
@@ -1350,10 +1350,12 @@
                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                               "seg_dict_file") else None,
                                                   punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
                                                   mode="train")
                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                               "seg_dict_file") else None,
                                                   punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
                                                   mode="eval")
            elif args.dataset_type == "small":
                train_iter_factory = cls.build_iter_factory(