| | |
| | | "and exclude_keys excludes keys of model states for the initialization." |
| | | "e.g.\n" |
| | | " # Load all parameters" |
| | | " --init_param some/where/model.pth\n" |
| | | " --init_param some/where/model.pb\n" |
| | | " # Load only decoder parameters" |
| | | " --init_param some/where/model.pth:decoder:decoder\n" |
| | | " --init_param some/where/model.pb:decoder:decoder\n" |
| | | " # Load only decoder parameters excluding decoder.embed" |
| | | " --init_param some/where/model.pth:decoder:decoder:decoder.embed\n" |
| | | " --init_param some/where/model.pth:decoder:decoder:decoder.embed\n", |
| | | " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n" |
| | | " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n", |
| | | ) |
| | | group.add_argument( |
| | | "--ignore_init_mismatch", |
| | |
| | | from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader |
| | | train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf, |
| | | frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, |
| | | "seg_dict_file") else None, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None, |
| | | bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None, |
| | | mode="train") |
| | | valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf, |
| | | frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, |
| | | "seg_dict_file") else None, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None, |
| | | bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None, |
| | | mode="eval") |
| | | elif args.dataset_type == "small": |
| | | train_iter_factory = cls.build_iter_factory( |