#!/usr/bin/env python3
"""Average the parameters of several PyTorch checkpoints into one file."""

import argparse

import torch


def average_model(input_files, output_file):
    """Write the element-wise average of several checkpoints to ``output_file``.

    Each input is loaded onto the CPU with ``torch.load`` and must be a
    mapping (anything exposing ``.items()``) from key to value. Values are
    summed per key across all inputs and then divided by the number of
    inputs; the resulting dict is stored with ``torch.save``.

    Floating-point and complex tensors (and non-tensor values) are
    true-divided. Integer tensors are floor-divided so they keep their dtype
    instead of failing on an in-place true division.

    Keys are unioned across the inputs, and every sum is divided by the
    total number of inputs. A key that is missing from some inputs is
    therefore not a true mean of the values that were present.

    Args:
        input_files: Non-empty sequence of checkpoint paths (or file-like
            objects) accepted by ``torch.load``.
        output_file: Destination path (or file-like object) accepted by
            ``torch.save``.

    Raises:
        ValueError: If ``input_files`` is empty.
    """
    if not input_files:
        raise ValueError("input_files must contain at least one checkpoint")

    num_models = len(input_files)
    output_model = {}
    for ckpt_path in input_files:
        # NOTE(security): torch.load unpickles its input. Only average
        # checkpoints from trusted sources.
        model_params = torch.load(ckpt_path, map_location="cpu")
        for key, value in model_params.items():
            if key not in output_model:
                # Clone so the in-place accumulation below never writes into
                # storage shared by several keys of the first checkpoint
                # (tied weights); otherwise each alias would receive every
                # later addend more than once.
                if isinstance(value, torch.Tensor):
                    value = value.clone()
                output_model[key] = value
            else:
                output_model[key] += value

    # Only values are replaced here; the key set is unchanged, so mutating
    # the dict while iterating over it is safe.
    for key, total in output_model.items():
        is_integer_tensor = isinstance(total, torch.Tensor) and not (
            total.is_floating_point() or total.is_complex()
        )
        if is_integer_tensor:
            # True division cannot be stored back into an integer tensor;
            # floor division keeps the original dtype.
            output_model[key] = torch.div(
                total, num_models, rounding_mode="floor"
            )
        else:
            output_model[key] = total / num_models

    torch.save(output_model, output_file)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Average the parameters of several PyTorch checkpoints."
    )
    parser.add_argument("output_file", help="path of the averaged checkpoint")
    parser.add_argument(
        "input_files", nargs='+', help="checkpoint files to average"
    )
    args = parser.parse_args()
    average_model(args.input_files, args.output_file)