维石
2024-05-28 e7351db81b3bfc4000633eca274c46893d68f64e
update export
5 files changed, 70 lines changed

Changed files:
examples/industrial_data_pretraining/paraformer/export.py (16 lines)
funasr/auto/auto_model.py (8 lines)
funasr/models/paraformer/export_meta.py (1 line)
funasr/models/seaco_paraformer/export_meta.py (7 lines)
funasr/utils/export_utils.py (38 lines)
examples/industrial_data_pretraining/paraformer/export.py
@@ -13,16 +13,16 @@
     model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
 )
-res = model.export(type="onnx", quantize=False)
+res = model.export(type="torchscript", quantize=False)
 print(res)

-# method2, inference from local path
-from funasr import AutoModel
+# # method2, inference from local path
+# from funasr import AutoModel

-model = AutoModel(
-    model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-)
+# model = AutoModel(
+#     model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+# )

-res = model.export(type="onnx", quantize=False)
-print(res)
+# res = model.export(type="onnx", quantize=False)
+# print(res)
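For reference, a minimal end-to-end sketch of the updated export call, based on the example above (model ID as in this file; per auto_model.py both calls return the export output directory):

from funasr import AutoModel

model = AutoModel(
    model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
)

# Export the same model in both formats; quantize=False keeps fp32 weights.
onnx_dir = model.export(type="onnx", quantize=False)
ts_dir = model.export(type="torchscript", quantize=False)
print(onnx_dir, ts_dir)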
funasr/auto/auto_model.py
@@ -580,12 +580,6 @@
         )
         with torch.no_grad():
-            if type == "onnx":
-                export_dir = export_utils.export_onnx(model=model, data_in=data_list, **kwargs)
-            else:
-                export_dir = export_utils.export_torchscripts(
-                    model=model, data_in=data_list, **kwargs
-                )
+            export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
         return export_dir
funasr/models/paraformer/export_meta.py
@@ -31,6 +31,7 @@
     model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
     model.export_name = types.MethodType(export_name, model)
+    model.export_name = 'model'
     return model
funasr/models/seaco_paraformer/export_meta.py
@@ -109,7 +109,9 @@
     backbone_model.export_dynamic_axes = types.MethodType(
         export_backbone_dynamic_axes, backbone_model
     )
-    backbone_model.export_name = types.MethodType(export_backbone_name, backbone_model)
+    embedder_model.export_name = "model_eb"
+    backbone_model.export_name = "model_bb"
     return backbone_model, embedder_model
@@ -192,6 +194,3 @@
         "pre_acoustic_embeds": {1: "feats_length1"},
     }
-
-def export_backbone_name(self):
-    return "model.onnx"
funasr/utils/export_utils.py
@@ -2,7 +2,7 @@
 import torch


-def export_onnx(model, data_in=None, quantize: bool = False, opset_version: int = 14, **kwargs):
+def export(model, data_in=None, quantize: bool = False, opset_version: int = 14, type='onnx', **kwargs):
     model_scripts = model.export(**kwargs)
     export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))
     os.makedirs(export_dir, exist_ok=True)
@@ -11,14 +11,20 @@
         model_scripts = (model_scripts,)
     for m in model_scripts:
         m.eval()
-        _onnx(
-            m,
-            data_in=data_in,
-            quantize=quantize,
-            opset_version=opset_version,
-            export_dir=export_dir,
-            **kwargs
-        )
+        if type == 'onnx':
+            _onnx(
+                m,
+                data_in=data_in,
+                quantize=quantize,
+                opset_version=opset_version,
+                export_dir=export_dir,
+                **kwargs
+            )
+        elif type == 'torchscript':
+            _torchscripts(
+                m,
+                path=export_dir,
+            )
         print("output dir: {}".format(export_dir))

     return export_dir
@@ -37,7 +43,7 @@
     verbose = kwargs.get("verbose", False)

-    export_name = model.export_name() if hasattr(model, "export_name") else "model.onnx"
+    export_name = model.export_name + '.onnx'
     model_path = os.path.join(export_dir, export_name)
     torch.onnx.export(
         model,
@@ -70,3 +76,15 @@
                 weight_type=QuantType.QUInt8,
                 nodes_to_exclude=nodes_to_exclude,
             )
+
+
+def _torchscripts(model, path, device='cpu'):
+    dummy_input = model.export_dummy_inputs()
+
+    if device == 'cuda':
+        model = model.cuda()
+        dummy_input = tuple([i.cuda() for i in dummy_input])
+
+    # model_script = torch.jit.script(model)
+    model_script = torch.jit.trace(model, dummy_input)
+    model_script.save(os.path.join(path, f'{model.export_name}.torchscripts'))
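For completeness, a sketch of consuming the TorchScript artifact that _torchscripts writes (the file name follows f'{model.export_name}.torchscripts' above; the path and input shapes below are placeholder assumptions, and real inputs must match what model.export_dummy_inputs() produces):

import torch

# Load the traced module; "model.torchscripts" follows the naming used in _torchscripts.
ts_model = torch.jit.load("exported_dir/model.torchscripts", map_location="cpu")
ts_model.eval()

with torch.no_grad():
    # Placeholder inputs: (batch, frames, feature_dim) plus frame lengths - shapes are assumptions.
    feats = torch.randn(1, 100, 560)
    feats_len = torch.tensor([100], dtype=torch.int32)
    outputs = ts_model(feats, feats_len)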