游雁
2023-06-20 43ad2c35634a3ed2a7a46bd7e3afd147934b1c48
funasr/export/export_model.py
@@ -229,34 +229,35 @@
        # model_script = torch.jit.script(model)
        model_script = model #torch.jit.trace(model)
        model_path = os.path.join(path, f'{model.model_name}.onnx')
        torch.onnx.export(
            model_script,
            dummy_input,
            model_path,
            verbose=verbose,
            opset_version=14,
            input_names=model.get_input_names(),
            output_names=model.get_output_names(),
            dynamic_axes=model.get_dynamic_axes()
        )
        if not os.path.exists(model_path):
            torch.onnx.export(
                model_script,
                dummy_input,
                model_path,
                verbose=verbose,
                opset_version=14,
                input_names=model.get_input_names(),
                output_names=model.get_output_names(),
                dynamic_axes=model.get_dynamic_axes()
            )
        if self.quant:
            from onnxruntime.quantization import QuantType, quantize_dynamic
            import onnx
            quant_model_path = os.path.join(path, f'{model.model_name}_quant.onnx')
            onnx_model = onnx.load(model_path)
            nodes = [n.name for n in onnx_model.graph.node]
            nodes_to_exclude = [m for m in nodes if 'output' in m]
            quantize_dynamic(
                model_input=model_path,
                model_output=quant_model_path,
                op_types_to_quantize=['MatMul'],
                per_channel=True,
                reduce_range=False,
                weight_type=QuantType.QUInt8,
                nodes_to_exclude=nodes_to_exclude,
            )
            if not os.path.exists(quant_model_path):
                onnx_model = onnx.load(model_path)
                nodes = [n.name for n in onnx_model.graph.node]
                nodes_to_exclude = [m for m in nodes if 'output' in m]
                quantize_dynamic(
                    model_input=model_path,
                    model_output=quant_model_path,
                    op_types_to_quantize=['MatMul'],
                    per_channel=True,
                    reduce_range=False,
                    weight_type=QuantType.QUInt8,
                    nodes_to_exclude=nodes_to_exclude,
                )
if __name__ == '__main__':