游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/export/export_model.py
@@ -1,16 +1,12 @@
import json
from typing import Union, Dict
from pathlib import Path
from typeguard import check_argument_types
import os
import logging
import torch
from funasr.export.models import get_model
import numpy as np
import random
from funasr.utils.types import str2bool
import logging
import numpy as np
from pathlib import Path
from typing import Union, Dict, List
from funasr.export.models import get_model
from funasr.utils.types import str2bool, str2triple_str
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
@@ -19,28 +15,29 @@
        self,
        cache_dir: Union[Path, str] = None,
        onnx: bool = True,
        device: str = "cpu",
        quant: bool = True,
        fallback_num: int = 0,
        audio_in: str = None,
        calib_num: int = 200,
        model_revision: str = None,
    ):
        assert check_argument_types()
        self.set_all_random_seed(0)
        if cache_dir is None:
            cache_dir = Path.home() / ".cache" / "export"
        self.cache_dir = Path(cache_dir)
        self.cache_dir = cache_dir
        self.export_config = dict(
            feats_dim=560,
            onnx=False,
        )
        print("output dir: {}".format(self.cache_dir))
        self.onnx = onnx
        self.device = device
        self.quant = quant
        self.fallback_num = fallback_num
        self.frontend = None
        self.audio_in = audio_in
        self.calib_num = calib_num
        self.model_revision = model_revision
        
    def _export(
@@ -50,7 +47,7 @@
        verbose: bool = False,
    ):
        export_dir = self.cache_dir / tag_name.replace(' ', '-')
        export_dir = self.cache_dir
        os.makedirs(export_dir, exist_ok=True)
        # export encoder1
@@ -59,14 +56,22 @@
            model,
            self.export_config,
        )
        model.eval()
        # self._export_onnx(model, verbose, export_dir)
        if self.onnx:
            self._export_onnx(model, verbose, export_dir)
        if isinstance(model, List):
            for m in model:
                m.eval()
                if self.onnx:
                    self._export_onnx(m, verbose, export_dir)
                else:
                    self._export_torchscripts(m, verbose, export_dir)
                print("output dir: {}".format(export_dir))
        else:
            self._export_torchscripts(model, verbose, export_dir)
        print("output dir: {}".format(export_dir))
            model.eval()
            # self._export_onnx(model, verbose, export_dir)
            if self.onnx:
                self._export_onnx(model, verbose, export_dir)
            else:
                self._export_torchscripts(model, verbose, export_dir)
            print("output dir: {}".format(export_dir))
    def _torch_quantize(self, model):
@@ -111,6 +116,10 @@
            dummy_input = model.get_dummy_inputs(enc_size)
        else:
            dummy_input = model.get_dummy_inputs()
        if self.device == 'cuda':
            model = model.cuda()
            dummy_input = tuple([i.cuda() for i in dummy_input])
        # model_script = torch.jit.script(model)
        model_script = torch.jit.trace(model, dummy_input)
@@ -161,31 +170,59 @@
    
    def export(self,
               tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
               mode: str = 'paraformer',
               mode: str = None,
               ):
        
        model_dir = tag_name
        if model_dir.startswith('damo/'):
        if model_dir.startswith('damo'):
            from modelscope.hub.snapshot_download import snapshot_download
            model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
        asr_train_config = os.path.join(model_dir, 'config.yaml')
        asr_model_file = os.path.join(model_dir, 'model.pb')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        json_file = os.path.join(model_dir, 'configuration.json')
            model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir, revision=self.model_revision)
        self.cache_dir = model_dir
        if mode is None:
            import json
            json_file = os.path.join(model_dir, 'configuration.json')
            with open(json_file, 'r') as f:
                config_data = json.load(f)
                mode = config_data['model']['model_config']['mode']
                if config_data['task'] == "punctuation":
                    mode = config_data['model']['punc_model_config']['mode']
                else:
                    mode = config_data['model']['model_config']['mode']
        if mode.startswith('paraformer'):
            from funasr.tasks.asr import ASRTaskParaformer as ASRTask
        elif mode.startswith('uniasr'):
            from funasr.tasks.asr import ASRTaskUniASR as ASRTask
            config = os.path.join(model_dir, 'config.yaml')
            model_file = os.path.join(model_dir, 'model.pb')
            cmvn_file = os.path.join(model_dir, 'am.mvn')
            model, asr_train_args = ASRTask.build_model_from_file(
                config, model_file, cmvn_file, 'cpu'
            )
            self.frontend = model.frontend
            self.export_config["feats_dim"] = 560
        elif mode.startswith('offline'):
            from funasr.tasks.vad import VADTask
            config = os.path.join(model_dir, 'vad.yaml')
            model_file = os.path.join(model_dir, 'vad.pb')
            cmvn_file = os.path.join(model_dir, 'vad.mvn')
            
        model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, 'cpu'
        )
        self.frontend = model.frontend
            model, vad_infer_args = VADTask.build_model_from_file(
                config, model_file, cmvn_file=cmvn_file, device='cpu'
            )
            self.export_config["feats_dim"] = 400
            self.frontend = model.frontend
        elif mode.startswith('punc'):
            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
            punc_train_config = os.path.join(model_dir, 'config.yaml')
            punc_model_file = os.path.join(model_dir, 'punc.pb')
            model, punc_train_args = PUNCTask.build_model_from_file(
                punc_train_config, punc_model_file, 'cpu'
            )
        elif mode.startswith('punc_VadRealtime'):
            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
            punc_train_config = os.path.join(model_dir, 'config.yaml')
            punc_model_file = os.path.join(model_dir, 'punc.pb')
            model, punc_train_args = PUNCTask.build_model_from_file(
                punc_train_config, punc_model_file, 'cpu'
            )
        self._export(model, tag_name)
            
@@ -198,7 +235,7 @@
        # model_script = torch.jit.script(model)
        model_script = model #torch.jit.trace(model)
        model_path = os.path.join(path, f'{model.model_name}.onnx')
        # if not os.path.exists(model_path):
        torch.onnx.export(
            model_script,
            dummy_input,
@@ -214,38 +251,46 @@
            from onnxruntime.quantization import QuantType, quantize_dynamic
            import onnx
            quant_model_path = os.path.join(path, f'{model.model_name}_quant.onnx')
            onnx_model = onnx.load(model_path)
            nodes = [n.name for n in onnx_model.graph.node]
            nodes_to_exclude = [m for m in nodes if 'output' in m]
            quantize_dynamic(
                model_input=model_path,
                model_output=quant_model_path,
                op_types_to_quantize=['MatMul'],
                per_channel=True,
                reduce_range=False,
                weight_type=QuantType.QUInt8,
                nodes_to_exclude=nodes_to_exclude,
            )
            if not os.path.exists(quant_model_path):
                onnx_model = onnx.load(model_path)
                nodes = [n.name for n in onnx_model.graph.node]
                nodes_to_exclude = [m for m in nodes if 'output' in m or 'bias_encoder' in m  or 'bias_decoder' in m]
                quantize_dynamic(
                    model_input=model_path,
                    model_output=quant_model_path,
                    op_types_to_quantize=['MatMul'],
                    per_channel=True,
                    reduce_range=False,
                    weight_type=QuantType.QUInt8,
                    nodes_to_exclude=nodes_to_exclude,
                )
if __name__ == '__main__':
    # Command-line entry point: build one ModelExport from the parsed options
    # and export every model named on the command line.
    import argparse

    parser = argparse.ArgumentParser()
    # --model-name may be given several times; each value is appended to a
    # list. It is registered exactly once, because argparse rejects a second
    # registration of the same option string.
    parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
    parser.add_argument('--export-dir', type=str, required=True)
    parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
    parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
    parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
    parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
    parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
    parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
    parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
    args = parser.parse_args()

    export_model = ModelExport(
        cache_dir=args.export_dir,
        onnx=args.type == 'onnx',  # any --type other than 'onnx' disables ONNX export
        device=args.device,
        quant=args.quantize,
        fallback_num=args.fallback_num,
        audio_in=args.audio_in,
        calib_num=args.calib_num,
        model_revision=args.model_revision,
    )
    # args.model_name is a list, so export the entries one at a time;
    # export() expects a single model id or directory string.
    for model_name in args.model_name:
        print("export model: {}".format(model_name))
        export_model.export(model_name)