游雁
2023-03-13 fc08b62d05723cdc1ce021bb8ba044ca014fb1f7
funasr/export/export_model.py
@@ -7,10 +7,12 @@
import logging
import torch
from funasr.bin.asr_inference_paraformer import Speech2Text
from funasr.export.models import get_model
import numpy as np
import random
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
class ASRModelExportParaformer:
    def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True):
@@ -30,7 +32,7 @@
    def _export(
        self,
        model: Speech2Text,
        model,
        tag_name: str = None,
        verbose: bool = False,
    ):
@@ -58,7 +60,7 @@
        if enc_size:
            dummy_input = model.get_dummy_inputs(enc_size)
        else:
            dummy_input = model.get_dummy_inputs_txt()
            dummy_input = model.get_dummy_inputs()
        # model_script = torch.jit.script(model)
        model_script = torch.jit.trace(model, dummy_input)
@@ -111,12 +113,13 @@
            dummy_input,
            os.path.join(path, f'{model.model_name}.onnx'),
            verbose=verbose,
            opset_version=12,
            opset_version=14,
            input_names=model.get_input_names(),
            output_names=model.get_output_names(),
            dynamic_axes=model.get_dynamic_axes()
        )
if __name__ == '__main__':
    import sys