| | |
| | | import json |
| | | from typing import Union, Dict |
| | | from pathlib import Path |
| | | |
| | | import os |
| | | import logging |
| | | import torch |
| | | |
| | | from funasr.export.models import get_model |
| | | import numpy as np |
| | | import random |
| | | import logging |
| | | import numpy as np |
| | | from pathlib import Path |
| | | from typing import Union, Dict, List |
| | | from funasr.export.models import get_model |
| | | from funasr.utils.types import str2bool, str2triple_str |
| | | # torch_version = float(".".join(torch.__version__.split(".")[:2])) |
| | | # assert torch_version > 1.9 |
| | |
| | | model, |
| | | self.export_config, |
| | | ) |
| | | model.eval() |
| | | # self._export_onnx(model, verbose, export_dir) |
| | | if self.onnx: |
| | | self._export_onnx(model, verbose, export_dir) |
| | | if isinstance(model, List): |
| | | for m in model: |
| | | m.eval() |
| | | if self.onnx: |
| | | self._export_onnx(m, verbose, export_dir) |
| | | else: |
| | | self._export_torchscripts(m, verbose, export_dir) |
| | | print("output dir: {}".format(export_dir)) |
| | | else: |
| | | self._export_torchscripts(model, verbose, export_dir) |
| | | |
| | | print("output dir: {}".format(export_dir)) |
| | | model.eval() |
| | | # self._export_onnx(model, verbose, export_dir) |
| | | if self.onnx: |
| | | self._export_onnx(model, verbose, export_dir) |
| | | else: |
| | | self._export_torchscripts(model, verbose, export_dir) |
| | | print("output dir: {}".format(export_dir)) |
| | | |
| | | |
| | | def _torch_quantize(self, model): |
| | |
| | | # model_script = torch.jit.script(model) |
| | | model_script = model #torch.jit.trace(model) |
| | | model_path = os.path.join(path, f'{model.model_name}.onnx') |
| | | if not os.path.exists(model_path): |
| | | torch.onnx.export( |
| | | model_script, |
| | | dummy_input, |
| | | model_path, |
| | | verbose=verbose, |
| | | opset_version=14, |
| | | input_names=model.get_input_names(), |
| | | output_names=model.get_output_names(), |
| | | dynamic_axes=model.get_dynamic_axes() |
| | | ) |
| | | # if not os.path.exists(model_path): |
| | | torch.onnx.export( |
| | | model_script, |
| | | dummy_input, |
| | | model_path, |
| | | verbose=verbose, |
| | | opset_version=14, |
| | | input_names=model.get_input_names(), |
| | | output_names=model.get_output_names(), |
| | | dynamic_axes=model.get_dynamic_axes() |
| | | ) |
| | | |
| | | if self.quant: |
| | | from onnxruntime.quantization import QuantType, quantize_dynamic |
| | |
| | | if not os.path.exists(quant_model_path): |
| | | onnx_model = onnx.load(model_path) |
| | | nodes = [n.name for n in onnx_model.graph.node] |
| | | nodes_to_exclude = [m for m in nodes if 'output' in m] |
| | | nodes_to_exclude = [m for m in nodes if 'output' in m or 'bias_encoder' in m or 'bias_decoder' in m] |
| | | quantize_dynamic( |
| | | model_input=model_path, |
| | | model_output=quant_model_path, |