huangmingming
2023-03-13 49c00a7d6cb9c05d4bd0bb0fc8b59a2eed4b8950
funasr/export/export_model.py
@@ -7,10 +7,12 @@
import logging
import torch
from funasr.bin.asr_inference_paraformer import Speech2Text
from funasr.export.models import get_model
import numpy as np
import random
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
class ASRModelExportParaformer:
    def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True):
@@ -30,7 +32,7 @@
    def _export(
        self,
        model: Speech2Text,
        model,
        tag_name: str = None,
        verbose: bool = False,
    ):
@@ -44,6 +46,7 @@
            model,
            self.export_config,
        )
        model.eval()
        # self._export_onnx(model, verbose, export_dir)
        if self.onnx:
            self._export_onnx(model, verbose, export_dir)
@@ -57,7 +60,7 @@
        if enc_size:
            dummy_input = model.get_dummy_inputs(enc_size)
        else:
            dummy_input = model.get_dummy_inputs_txt()
            dummy_input = model.get_dummy_inputs()
        # model_script = torch.jit.script(model)
        model_script = torch.jit.trace(model, dummy_input)
@@ -85,9 +88,9 @@
            with open(json_file, 'r') as f:
                config_data = json.load(f)
                mode = config_data['model']['model_config']['mode']
        if mode == 'paraformer':
        if mode.startswith('paraformer'):
            from funasr.tasks.asr import ASRTaskParaformer as ASRTask
        elif mode == 'uniasr':
        elif mode.startswith('uniasr'):
            from funasr.tasks.asr import ASRTaskUniASR as ASRTask
            
        model, asr_train_args = ASRTask.build_model_from_file(
@@ -110,12 +113,13 @@
            dummy_input,
            os.path.join(path, f'{model.model_name}.onnx'),
            verbose=verbose,
            opset_version=12,
            opset_version=14,
            input_names=model.get_input_names(),
            output_names=model.get_output_names(),
            dynamic_axes=model.get_dynamic_axes()
        )
if __name__ == '__main__':
    import sys