| | |
| | | import json |
| | | from typing import Union, Dict |
| | | from pathlib import Path |
| | | |
| | | import os |
| | | import logging |
| | | import torch |
| | | |
| | | from funasr.export.models import get_model |
| | | import numpy as np |
| | | import random |
| | | import logging |
| | | import numpy as np |
| | | from pathlib import Path |
| | | from typing import Union, Dict, List |
| | | from funasr.export.models import get_model |
| | | from funasr.utils.types import str2bool, str2triple_str |
| | | # torch_version = float(".".join(torch.__version__.split(".")[:2])) |
| | | # assert torch_version > 1.9 |
| | |
| | | |
| | | # export encoder1 |
| | | self.export_config["model_name"] = "model" |
| | | models = get_model( |
| | | model = get_model( |
| | | model, |
| | | self.export_config, |
| | | ) |
| | | if not isinstance(models, tuple): |
| | | models = (models,) |
| | | |
| | | for i, model in enumerate(models): |
| | | if isinstance(model, List): |
| | | for m in model: |
| | | m.eval() |
| | | if self.onnx: |
| | | self._export_onnx(m, verbose, export_dir) |
| | | else: |
| | | self._export_torchscripts(m, verbose, export_dir) |
| | | print("output dir: {}".format(export_dir)) |
| | | else: |
| | | model.eval() |
| | | # self._export_onnx(model, verbose, export_dir) |
| | | if self.onnx: |
| | | self._export_onnx(model, verbose, export_dir) |
| | | else: |
| | | self._export_torchscripts(model, verbose, export_dir) |
| | | |
| | | print("output dir: {}".format(export_dir)) |
| | | |
| | | |
| | |
| | | # model_script = torch.jit.script(model) |
| | | model_script = model #torch.jit.trace(model) |
| | | model_path = os.path.join(path, f'{model.model_name}.onnx') |
| | | if not os.path.exists(model_path): |
| | | torch.onnx.export( |
| | | model_script, |
| | | dummy_input, |
| | | model_path, |
| | | verbose=verbose, |
| | | opset_version=14, |
| | | input_names=model.get_input_names(), |
| | | output_names=model.get_output_names(), |
| | | dynamic_axes=model.get_dynamic_axes() |
| | | ) |
| | | # if not os.path.exists(model_path): |
| | | torch.onnx.export( |
| | | model_script, |
| | | dummy_input, |
| | | model_path, |
| | | verbose=verbose, |
| | | opset_version=14, |
| | | input_names=model.get_input_names(), |
| | | output_names=model.get_output_names(), |
| | | dynamic_axes=model.get_dynamic_axes() |
| | | ) |
| | | |
| | | if self.quant: |
| | | from onnxruntime.quantization import QuantType, quantize_dynamic |