游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/export/export_model.py
@@ -1,15 +1,12 @@
import json
from typing import Union, Dict
from pathlib import Path
import os
import logging
import torch
from funasr.export.models import get_model
import numpy as np
import random
from funasr.utils.types import str2bool
import logging
import numpy as np
from pathlib import Path
from typing import Union, Dict, List
from funasr.export.models import get_model
from funasr.utils.types import str2bool, str2triple_str
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
@@ -59,14 +56,22 @@
            model,
            self.export_config,
        )
        model.eval()
        # self._export_onnx(model, verbose, export_dir)
        if self.onnx:
            self._export_onnx(model, verbose, export_dir)
        if isinstance(model, List):
            for m in model:
                m.eval()
                if self.onnx:
                    self._export_onnx(m, verbose, export_dir)
                else:
                    self._export_torchscripts(m, verbose, export_dir)
                print("output dir: {}".format(export_dir))
        else:
            self._export_torchscripts(model, verbose, export_dir)
        print("output dir: {}".format(export_dir))
            model.eval()
            # self._export_onnx(model, verbose, export_dir)
            if self.onnx:
                self._export_onnx(model, verbose, export_dir)
            else:
                self._export_torchscripts(model, verbose, export_dir)
            print("output dir: {}".format(export_dir))
    def _torch_quantize(self, model):
@@ -192,6 +197,7 @@
                config, model_file, cmvn_file, 'cpu'
            )
            self.frontend = model.frontend
            self.export_config["feats_dim"] = 560
        elif mode.startswith('offline'):
            from funasr.tasks.vad import VADTask
            config = os.path.join(model_dir, 'vad.yaml')
@@ -229,17 +235,17 @@
        # model_script = torch.jit.script(model)
        model_script = model #torch.jit.trace(model)
        model_path = os.path.join(path, f'{model.model_name}.onnx')
        if not os.path.exists(model_path):
            torch.onnx.export(
                model_script,
                dummy_input,
                model_path,
                verbose=verbose,
                opset_version=14,
                input_names=model.get_input_names(),
                output_names=model.get_output_names(),
                dynamic_axes=model.get_dynamic_axes()
            )
        # if not os.path.exists(model_path):
        torch.onnx.export(
            model_script,
            dummy_input,
            model_path,
            verbose=verbose,
            opset_version=14,
            input_names=model.get_input_names(),
            output_names=model.get_output_names(),
            dynamic_axes=model.get_dynamic_axes()
        )
        if self.quant:
            from onnxruntime.quantization import QuantType, quantize_dynamic
@@ -248,7 +254,7 @@
            if not os.path.exists(quant_model_path):
                onnx_model = onnx.load(model_path)
                nodes = [n.name for n in onnx_model.graph.node]
                nodes_to_exclude = [m for m in nodes if 'output' in m]
                nodes_to_exclude = [m for m in nodes if 'output' in m or 'bias_encoder' in m  or 'bias_decoder' in m]
                quantize_dynamic(
                    model_input=model_path,
                    model_output=quant_model_path,
@@ -263,7 +269,8 @@
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-name', type=str, required=True)
    # parser.add_argument('--model-name', type=str, required=True)
    parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
    parser.add_argument('--export-dir', type=str, required=True)
    parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
    parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
@@ -284,4 +291,6 @@
        calib_num=args.calib_num,
        model_revision=args.model_revision,
    )
    export_model.export(args.model_name)
    for model_name in args.model_name:
        print("export model: {}".format(model_name))
        export_model.export(model_name)