游雁
2024-12-23 1e5ef6ed9a6f64ecca7b9ef9481519b271f793a3
funasr/utils/export_utils.py
@@ -1,12 +1,10 @@
import os
import torch
import functools
import onnx
from onnxconverter_common import float16
import warnings
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore")
def export(
@@ -44,14 +42,13 @@
                print(f"export_dir: {export_dir}")
                _torchscripts(m, path=export_dir, device="cuda")
        elif type=='onnx_fp16':
        elif type == "onnx_fp16":
            assert (
                torch.cuda.is_available()
            ), "Currently onnx_fp16 optimization for FunASR only supports GPU"
            ), "Currently onnx_fp16 optimization for FunASR only supports GPU"
            if hasattr(m, "encoder") and hasattr(m, "decoder"):
                _onnx_opt_for_encdec(m, path=export_dir, enable_fp16=True)
                _onnx_opt_for_encdec(m, path=export_dir, enable_fp16=True)
    return export_dir
@@ -65,9 +62,13 @@
    **kwargs,
):
    device = kwargs.get("device", "cpu")
    dummy_input = model.export_dummy_inputs()
    dummy_input = (dummy_input[0].to("cuda"), dummy_input[1].to("cuda"))
    if isinstance(dummy_input, torch.Tensor):
        dummy_input = dummy_input.to(device)
    else:
        dummy_input = tuple([input.to(device) for input in dummy_input])
    verbose = kwargs.get("verbose", False)
@@ -89,8 +90,13 @@
    )
    if quantize:
        from onnxruntime.quantization import QuantType, quantize_dynamic
        import onnx
        try:
            from onnxruntime.quantization import QuantType, quantize_dynamic
            import onnx
        except:
            raise RuntimeError(
                "You are quantizing the onnx model, please install onnxruntime first. via \n`pip install onnx`\n`pip install onnxruntime`."
            )
        quant_model_path = model_path.replace(".onnx", "_quant.onnx")
        onnx_model = onnx.load(model_path)
@@ -112,19 +118,21 @@
def _torchscripts(model, path, device="cuda"):
    """Trace ``model`` with TorchScript and save the traced module under ``path``.

    The example inputs used for tracing come from ``model.export_dummy_inputs()``,
    which may return either a single ``torch.Tensor`` or an iterable of tensors.

    Args:
        model: module to trace. It must provide ``export_dummy_inputs()`` and an
            ``export_name`` that is either a string or a zero-argument callable
            returning the name.
        path: directory the traced module is written into (it is not created here).
        device: when ``"cuda"``, the model and dummy inputs are moved to the GPU
            before tracing; any other value leaves them where they already are.
    """
    dummy_input = model.export_dummy_inputs()
    if device == "cuda":
        model = model.cuda()
        if isinstance(dummy_input, torch.Tensor):
            dummy_input = dummy_input.cuda()
        else:
            dummy_input = tuple(i.cuda() for i in dummy_input)

    model_script = torch.jit.trace(model, dummy_input)

    # ``export_name`` is a plain string attribute on some models and a method on
    # others; resolve it once so the traced module is written exactly once.
    export_name = model.export_name
    if not isinstance(export_name, str):
        export_name = export_name()

    # Derive the file name from the ONNX export name, e.g. "model.onnx" -> "model.torchscript".
    # NOTE(review): str.replace rewrites every "onnx" occurrence in the name, not only the suffix.
    model_script.save(os.path.join(path, f"{export_name}".replace("onnx", "torchscript")))
def _bladedisc_opt(model, model_inputs, enable_fp16=True):
@@ -220,7 +228,6 @@
    model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript"))
def _onnx_opt_for_encdec(model, path, enable_fp16):
    # Get input data
@@ -262,16 +269,19 @@
            input_names=model.export_input_names(),
            output_names=model.export_output_names(),
            dynamic_axes=model.export_dynamic_axes(),
        )
        )
    # fp32 to fp16
    fp16_model_path = f"{path}/{model.export_name}_hook_fp16.onnx"
    print("*" * 50)
    print(f"[_onnx_opt_for_encdec(fp16)]: {fp16_model_path}\n\n")
    if os.path.exists(fp32_model_path) and not os.path.exists(fp16_model_path):
        try:
            from onnxconverter_common import float16
        except:
            raise RuntimeError(
                "You are converting the onnx model to fp16, please install onnxconverter-common first. via `pip install onnxconverter-common`."
            )
        fp32_onnx_model = onnx.load(fp32_model_path)
        fp16_onnx_model = float16.convert_float_to_float16(fp32_onnx_model, keep_io_types=True)
        onnx.save(
            fp16_onnx_model, fp16_model_path
        )
        onnx.save(fp16_onnx_model, fp16_model_path)