From 23e7ddebccd3b05cf7ef89809bcfe565ad6dfa1f Mon Sep 17 00:00:00 2001
From: majic31 <majic31@163.com>
Date: 星期二, 24 十二月 2024 10:00:14 +0800
Subject: [PATCH] Fix the variable name (#2328)

---
 funasr/utils/export_utils.py |  133 +++++++++++++++++++++++++++++++++++++-------
 1 files changed, 112 insertions(+), 21 deletions(-)

diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
index a6d0798..ca04d75 100644
--- a/funasr/utils/export_utils.py
+++ b/funasr/utils/export_utils.py
@@ -2,6 +2,10 @@
 import torch
 import functools
 
+import warnings
+
+warnings.filterwarnings("ignore")
+
 
 def export(
     model, data_in=None, quantize: bool = False, opset_version: int = 14, type="onnx", **kwargs
@@ -35,8 +39,16 @@
             if hasattr(m, "encoder") and hasattr(m, "decoder"):
                 _bladedisc_opt_for_encdec(m, path=export_dir, enable_fp16=True)
             else:
+                print(f"export_dir: {export_dir}")
                 _torchscripts(m, path=export_dir, device="cuda")
-        print("output dir: {}".format(export_dir))
+
+        elif type == "onnx_fp16":
+            assert (
+                torch.cuda.is_available()
+            ), "Currently onnx_fp16 optimization for FunASR only supports GPU"
+
+            if hasattr(m, "encoder") and hasattr(m, "decoder"):
+                _onnx_opt_for_encdec(m, path=export_dir, enable_fp16=True)
 
     return export_dir
 
@@ -50,17 +62,27 @@
     **kwargs,
 ):
 
+    device = kwargs.get("device", "cpu")
     dummy_input = model.export_dummy_inputs()
+
+    if isinstance(dummy_input, torch.Tensor):
+        dummy_input = dummy_input.to(device)
+    else:
+        dummy_input = tuple([input.to(device) for input in dummy_input])
 
     verbose = kwargs.get("verbose", False)
 
-    export_name = model.export_name + ".onnx"
+    if isinstance(model.export_name, str):
+        export_name = model.export_name + ".onnx"
+    else:
+        export_name = model.export_name()
     model_path = os.path.join(export_dir, export_name)
     torch.onnx.export(
         model,
         dummy_input,
         model_path,
         verbose=verbose,
+        do_constant_folding=True,
         opset_version=opset_version,
         input_names=model.export_input_names(),
         output_names=model.export_output_names(),
@@ -68,25 +90,30 @@
     )
 
     if quantize:
-        from onnxruntime.quantization import QuantType, quantize_dynamic
-        import onnx
+        try:
+            from onnxruntime.quantization import QuantType, quantize_dynamic
+            import onnx
+        except:
+            raise RuntimeError(
+                "You are quantizing the onnx model, please install onnxruntime first. via \n`pip install onnx`\n`pip install onnxruntime`."
+            )
 
         quant_model_path = model_path.replace(".onnx", "_quant.onnx")
-        if not os.path.exists(quant_model_path):
-            onnx_model = onnx.load(model_path)
-            nodes = [n.name for n in onnx_model.graph.node]
-            nodes_to_exclude = [
-                m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
-            ]
-            quantize_dynamic(
-                model_input=model_path,
-                model_output=quant_model_path,
-                op_types_to_quantize=["MatMul"],
-                per_channel=True,
-                reduce_range=False,
-                weight_type=QuantType.QUInt8,
-                nodes_to_exclude=nodes_to_exclude,
-            )
+        onnx_model = onnx.load(model_path)
+        nodes = [n.name for n in onnx_model.graph.node]
+        nodes_to_exclude = [
+            m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
+        ]
+        print("Quantizing model from {} to {}".format(model_path, quant_model_path))
+        quantize_dynamic(
+            model_input=model_path,
+            model_output=quant_model_path,
+            op_types_to_quantize=["MatMul"],
+            per_channel=True,
+            reduce_range=False,
+            weight_type=QuantType.QUInt8,
+            nodes_to_exclude=nodes_to_exclude,
+        )
 
 
 def _torchscripts(model, path, device="cuda"):
@@ -100,7 +127,12 @@
             dummy_input = tuple([i.cuda() for i in dummy_input])
 
     model_script = torch.jit.trace(model, dummy_input)
-    model_script.save(os.path.join(path, f"{model.export_name}.torchscript"))
+    if isinstance(model.export_name, str):
+        model_script.save(os.path.join(path, f"{model.export_name}".replace("onnx", "torchscript")))
+    else:
+        model_script.save(
+            os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript"))
+        )
 
 
 def _bladedisc_opt(model, model_inputs, enable_fp16=True):
@@ -153,7 +185,7 @@
 
     # Rescale encoder modules
     fp16_scale = int(2 * absmax // 65536)
-    print(f"rescale encoder modules with factor={fp16_scale}")
+    print(f"rescale encoder modules with factor={fp16_scale}\n\n")
     model.encoder.model.encoders0.register_forward_pre_hook(
         functools.partial(_rescale_input_hook, scale=fp16_scale),
     )
@@ -194,3 +226,62 @@
     model.decoder = _bladedisc_opt(model.decoder, tuple(decoder_inputs))
     model_script = torch.jit.trace(model, input_data)
     model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript"))
+
+
+def _onnx_opt_for_encdec(model, path, enable_fp16):
+
+    # Get input data
+    # TODO: better to use real data
+    input_data = model.export_dummy_inputs()
+
+    if isinstance(input_data, torch.Tensor):
+        input_data = input_data.cuda()
+    else:
+        input_data = tuple([i.cuda() for i in input_data])
+
+    # Get input data for decoder module
+    decoder_inputs = list()
+
+    def get_input_hook(m, x):
+        decoder_inputs.extend(list(x))
+
+    hook = model.decoder.register_forward_pre_hook(get_input_hook)
+    model = model.cuda()
+    model(*input_data)
+    hook.remove()
+
+    # Prevent FP16 overflow
+    if enable_fp16:
+        _rescale_encoder_model(model, input_data)
+
+    fp32_model_path = f"{path}/{model.export_name}_hook.onnx"
+    print("*" * 50)
+    print(f"[_onnx_opt_for_encdec(fp32)]: {fp32_model_path}\n\n")
+    if not os.path.exists(fp32_model_path):
+
+        torch.onnx.export(
+            model,
+            input_data,
+            fp32_model_path,
+            verbose=False,
+            do_constant_folding=True,
+            opset_version=13,
+            input_names=model.export_input_names(),
+            output_names=model.export_output_names(),
+            dynamic_axes=model.export_dynamic_axes(),
+        )
+
+    # fp32 to fp16
+    fp16_model_path = f"{path}/{model.export_name}_hook_fp16.onnx"
+    print("*" * 50)
+    print(f"[_onnx_opt_for_encdec(fp16)]: {fp16_model_path}\n\n")
+    if os.path.exists(fp32_model_path) and not os.path.exists(fp16_model_path):
+        try:
+            from onnxconverter_common import float16
+        except:
+            raise RuntimeError(
+                "You are converting the onnx model to fp16, please install onnxconverter-common first. via `pip install onnxconverter-common`."
+            )
+        fp32_onnx_model = onnx.load(fp32_model_path)
+        fp16_onnx_model = float16.convert_float_to_float16(fp32_onnx_model, keep_io_types=True)
+        onnx.save(fp16_onnx_model, fp16_model_path)

--
Gitblit v1.9.1