# Module setup for the model-export utilities.
# Imports grouped per convention: stdlib first, then third-party.
import functools
import os
import warnings

import torch
import onnx
from onnxconverter_common import float16

# Silence noisy deprecation/tracing warnings emitted by torch/onnx during
# model export. NOTE(review): the original called filterwarnings("ignore")
# twice; a single call is sufficient and behaviorally identical.
warnings.filterwarnings("ignore")
| | | def export( |
| | |
| | | print(f"export_dir: {export_dir}") |
| | | _torchscripts(m, path=export_dir, device="cuda") |
| | | |
| | | |
| | | elif type=='onnx_fp16': |
| | | elif type == "onnx_fp16": |
| | | assert ( |
| | | torch.cuda.is_available() |
| | | ), "Currently onnx_fp16 optimization for FunASR only supports GPU" |
| | |
| | | else: |
| | | dummy_input = tuple([input.to(device) for input in dummy_input]) |
| | | |
| | | |
| | | verbose = kwargs.get("verbose", False) |
| | | |
| | | if isinstance(model.export_name, str): |
| | |
| | | ) |
| | | |
# NOTE(review): fragment — dynamic-quantization setup from the ONNX export
# helper; the enclosing `def` is not visible in this chunk, so `quantize`
# and `model_path` come from the missing context, and the
# `quantize_dynamic(...)` call that consumes `quant_model_path` /
# `onnx_model` lies below this view.
if quantize:
    try:
        from onnxruntime.quantization import QuantType, quantize_dynamic
        import onnx
    except ImportError:
        # Narrowed from a bare `except:` so that only a genuinely missing
        # package triggers the install hint; other errors propagate.
        raise RuntimeError(
            "You are quantizing the onnx model, please install onnxruntime first. via \n`pip install onnx`\n`pip install onnxruntime`."
        )

    # Quantized artifact is written next to the fp32 model.
    quant_model_path = model_path.replace(".onnx", "_quant.onnx")
    onnx_model = onnx.load(model_path)
| | |
# NOTE(review): fragment — torchscript persistence step; the enclosing
# `def` (presumably `_torchscripts`) is not visible in this chunk, so
# `model`, `model_script` and `path` come from the missing context.
# `model.export_name` may be either a plain string or a zero-arg callable
# returning one.
if isinstance(model.export_name, str):
    export_name = model.export_name
else:
    export_name = model.export_name()
# The original repeated the callable-branch save a second time
# unconditionally (merge artifact); the scripted model is saved exactly
# once here, under the export name with "onnx" swapped for "torchscript".
model_script.save(os.path.join(path, f"{export_name}".replace("onnx", "torchscript")))
| | | |
| | | |
| | | def _bladedisc_opt(model, model_inputs, enable_fp16=True): |
| | |
| | | model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript")) |
| | | |
| | | |
| | | |
| | | def _onnx_opt_for_encdec(model, path, enable_fp16): |
| | | |
| | | # Get input data |
| | |
| | | dynamic_axes=model.export_dynamic_axes(), |
| | | ) |
| | | |
| | | |
# NOTE(review): fragment — the fp32→fp16 conversion tail of
# `_onnx_opt_for_encdec`; the export call that produces `fp32_model_path`
# is missing from this chunk (presumably defined just above — confirm).

# fp32 to fp16
fp16_model_path = f"{path}/{model.export_name}_hook_fp16.onnx"
print("*" * 50)
print(f"[_onnx_opt_for_encdec(fp16)]: {fp16_model_path}\n\n")
# Convert only when the fp32 source exists and the fp16 artifact does not,
# so repeated runs do not redo (or clobber) the conversion.
if os.path.exists(fp32_model_path) and not os.path.exists(fp16_model_path):
    try:
        from onnxconverter_common import float16
    except ImportError:
        # Narrowed from a bare `except:` so only a missing package
        # triggers the install hint; other errors propagate.
        raise RuntimeError(
            "You are converting the onnx model to fp16, please install onnxconverter-common first. via `pip install onnxconverter-common`."
        )
    fp32_onnx_model = onnx.load(fp32_model_path)
    # keep_io_types=True keeps graph inputs/outputs in fp32 so callers do
    # not have to change their feed/fetch dtypes.
    fp16_onnx_model = float16.convert_float_to_float16(fp32_onnx_model, keep_io_types=True)
    # The original saved the fp16 model twice (merge artifact); saving
    # once is sufficient.
    onnx.save(fp16_onnx_model, fp16_model_path)