游雁
2023-05-12 d25d0942f98ba5d879a2a39a8c30201c6496f3ae
onnx export funasr_onnx
4个文件已修改
81 ■■■■ 已修改文件
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py 13 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py 30 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py 37 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -32,7 +32,7 @@
                 plot_timestamp_to: str = "",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir=None
                 cache_dir: str = None
                 ):
        if not Path(model_dir).exists():
@@ -41,6 +41,12 @@
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except:
                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            print(".onnx is not exist, begin to export onnx")
            from funasr.export.export_model import ModelExport
            export_model = ModelExport(
                cache_dir=cache_dir,
@@ -50,11 +56,6 @@
            )
            export_model.export(model_dir)
            
        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -24,15 +24,32 @@
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None,
                 ):
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')
            from modelscope.hub.snapshot_download import snapshot_download
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except:
                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
                    model_dir)
        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            print(".onnx is not exist, begin to export onnx")
            from funasr.export.export_model import ModelExport
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)
        config_file = os.path.join(model_dir, 'punc.yaml')
        config = read_yaml(config_file)
@@ -135,9 +152,10 @@
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
        super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads)
        super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads, cache_dir=cache_dir)
    def __call__(self, text: str, param_dict: map, split_size=20):
        cache_key = "cache"
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -271,4 +271,5 @@
    logger.addHandler(sh)
    logger_initialized[name] = True
    logger.propagate = False
    logging.basicConfig(level=logging.ERROR)
    return logger
funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
@@ -31,14 +31,30 @@
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 max_end_sil: int = None,
                 cache_dir: str = None
                 ):
        
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')
            from modelscope.hub.snapshot_download import snapshot_download
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except:
                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
                    model_dir)
        
        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            print(".onnx is not exist, begin to export onnx")
            from funasr.export.export_model import ModelExport
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)
        config_file = os.path.join(model_dir, 'vad.yaml')
        cmvn_file = os.path.join(model_dir, 'vad.mvn')
        config = read_yaml(config_file)
@@ -172,14 +188,29 @@
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 max_end_sil: int = None,
                 cache_dir: str = None
                 ):
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')
            from modelscope.hub.snapshot_download import snapshot_download
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except:
                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
                    model_dir)
        
        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            print(".onnx is not exist, begin to export onnx")
            from funasr.export.export_model import ModelExport
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)
        config_file = os.path.join(model_dir, 'vad.yaml')
        cmvn_file = os.path.join(model_dir, 'vad.mvn')
        config = read_yaml(config_file)