    verbose = kwargs.get("verbose", False)

    # export_name may be a plain string or a callable that builds the file name.
    if isinstance(model.export_name, str):
        export_name = model.export_name + ".onnx"
    else:
        export_name = model.export_name()
    model_path = os.path.join(export_dir, export_name)

    # Trace with the model-provided dummy inputs (also used by _torchscripts below).
    # NOTE: the call is completed minimally here; typical exports also pass
    # input/output names, dynamic axes and an opset_version, which are not shown
    # in this excerpt.
    dummy_input = model.export_dummy_inputs()
    torch.onnx.export(
        model,
        dummy_input,
        model_path,
        verbose=verbose,
    )

    # quantize_dynamic / QuantType come from onnxruntime; onnx is used to inspect
    # the exported graph so selected nodes can be excluded from quantization.
    from onnxruntime.quantization import QuantType, quantize_dynamic
    import onnx

    quant_model_path = model_path.replace(".onnx", "_quant.onnx")
    if not os.path.exists(quant_model_path):
        onnx_model = onnx.load(model_path)
        nodes = [n.name for n in onnx_model.graph.node]
        # Keep output projections and the bias encoder/decoder in full precision.
        nodes_to_exclude = [
            m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
        ]
        print("Quantizing model from {} to {}".format(model_path, quant_model_path))
        quantize_dynamic(
            model_input=model_path,
            model_output=quant_model_path,
            op_types_to_quantize=["MatMul"],
            per_channel=True,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
            nodes_to_exclude=nodes_to_exclude,
        )
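

def _check_quantized_onnx(quant_model_path):
    # Hypothetical helper (not part of the original file): a minimal sketch of
    # loading the dynamically quantized model produced above with onnxruntime
    # and running it once on zero-filled dummy inputs, assuming onnxruntime and
    # numpy are installed.
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(quant_model_path, providers=["CPUExecutionProvider"])
    feeds = {}
    for inp in sess.get_inputs():
        # Symbolic/unknown dimensions are replaced with 1 so a dummy tensor can be built;
        # integer inputs are assumed to be int64 (e.g. token ids), everything else float32.
        shape = [d if isinstance(d, int) else 1 for d in inp.shape]
        dtype = np.int64 if "int" in inp.type else np.float32
        feeds[inp.name] = np.zeros(shape, dtype=dtype)
    return sess.run(None, feeds)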


def _torchscripts(model, path, device="cuda"):
    dummy_input = model.export_dummy_inputs()

    # Move the model and its tracing inputs to GPU before tracing, if requested.
    if device == "cuda":
        model = model.cuda()
        if isinstance(dummy_input, torch.Tensor):
            dummy_input = dummy_input.cuda()
        else:
            dummy_input = tuple([i.cuda() for i in dummy_input])

    model_script = torch.jit.trace(model, dummy_input)
    # As in the ONNX export above, export_name is either a plain string (without
    # an extension) or a callable that returns a full "*.onnx"-style file name.
    if isinstance(model.export_name, str):
        model_script.save(os.path.join(path, f"{model.export_name}.torchscript"))
    else:
        model_script.save(os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript")))
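

def _load_torchscript(script_path, device="cpu"):
    # Hypothetical helper (not part of the original file): a minimal sketch of
    # reloading a module saved by _torchscripts and preparing it for inference;
    # `script_path` is assumed to point at the saved .torchscript file.
    scripted = torch.jit.load(script_path, map_location=device)
    scripted.eval()
    return scripted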


def _bladedisc_opt(model, model_inputs, enable_fp16=True):