| | |
| | | return {k: to_plain_list(v) for k, v in cfg_item.items()} |
| | | else: |
| | | return cfg_item |
| | | |
| | | kwargs = to_plain_list(cfg) |
| | | log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) |
| | | |
| | | logging.basicConfig(level=log_level) |
| | | kwargs = to_plain_list(cfg) |
| | | |
| | | if kwargs.get("debug", False): |
| | | import pdb; pdb.set_trace() |
| | | import pdb |
| | | |
| | | pdb.set_trace() |
| | | |
| | | if "device" not in kwargs: |
| | | kwargs["device"] = "cpu" |
| | | model = AutoModel(**kwargs) |
| | | |
| | | res = model.export(input=kwargs.get("input", None), |
| | | type=kwargs.get("type", "onnx"), |
| | | quantize=kwargs.get("quantize", False), |
| | | fallback_num=kwargs.get("fallback-num", 5), |
| | | calib_num=kwargs.get("calib_num", 100), |
| | | opset_version=kwargs.get("opset_version", 14), |
| | | ) |
| | | |
| | | res = model.export( |
| | | input=kwargs.get("input", None), |
| | | type=kwargs.get("type", "onnx"), |
| | | quantize=kwargs.get("quantize", False), |
| | | fallback_num=kwargs.get("fallback-num", 5), |
| | | calib_num=kwargs.get("calib_num", 100), |
| | | opset_version=kwargs.get("opset_version", 14), |
| | | ) |
| | | print(res) |
| | | |
| | | |
# Script entry point. The original file had two back-to-back
# `if __name__ == "__main__":` guards, apparently left over from a merged
# diff, so main_hydra() ran twice per invocation. A single guard invokes it
# exactly once.
if __name__ == "__main__":
    main_hydra()