From 8a620a5a36df782e1f9e8cc56064d5dc6a1330b5 Mon Sep 17 00:00:00 2001
From: wanchen.swc <wanchen.swc@alibaba-inc.com>
Date: Wed, 15 Mar 2023 15:31:31 +0800
Subject: [PATCH] [Quantization] automatic mixed precision quantization
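
Add an optional fallback_num argument to ASRModelExportParaformer.
When fallback_num > 0, an automatic mixed precision (AMP) pass from
torch_quant runs after calibration: quantizer.amp() builds an AMP
model, the calibration data is replayed through it, and
quantizer.fallback() keeps the fallback_num layers selected by the AMP
analysis in floating point (their names land in
module_filter.exclude_names and are printed). The __main__ entry point
also moves from positional sys.argv parsing to argparse.

Example invocation (model name and export dir come from the old inline
comments; the fallback count of 8 is illustrative only):

    python funasr/export/export_model.py \
        --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch \
        --export-dir ../export \
        --type torch \
        --quantize \
        --fallback-num 8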
---
funasr/export/export_model.py | 57 ++++++++++++++++++++++++++++++++++++---------------------
 1 file changed, 36 insertions(+), 21 deletions(-)
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index 7370c3c..beb1efe 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -16,7 +16,11 @@
class ASRModelExportParaformer:
def __init__(
- self, cache_dir: Union[Path, str] = None, onnx: bool = True, quant: bool = True
+ self,
+ cache_dir: Union[Path, str] = None,
+ onnx: bool = True,
+ quant: bool = True,
+ fallback_num: int = 0,
):
assert check_argument_types()
self.set_all_random_seed(0)
@@ -31,6 +35,7 @@
print("output dir: {}".format(self.cache_dir))
self.onnx = onnx
self.quant = quant
+ self.fallback_num = fallback_num
def _export(
@@ -60,8 +65,12 @@
def _torch_quantize(self, model):
+ def _run_calibration_data(m):
+ # using dummy inputs as example calibration data
+ dummy_input = model.get_dummy_inputs()
+ m(*dummy_input)
+
from torch_quant.module import ModuleFilter
- from torch_quant.observer import HistogramObserver
from torch_quant.quantizer import Backend, Quantizer
from funasr.export.models.modules.decoder_layer import DecoderLayerSANM
from funasr.export.models.modules.encoder_layer import EncoderLayerSANM
@@ -70,16 +79,20 @@
quantizer = Quantizer(
module_filter=module_filter,
backend=Backend.FBGEMM,
- act_ob_ctr=HistogramObserver,
)
model.eval()
calib_model = quantizer.calib(model)
- # run calibration data
- # using dummy inputs for a example
- dummy_input = model.get_dummy_inputs()
- _ = calib_model(*dummy_input)
+ _run_calibration_data(calib_model)
+ if self.fallback_num > 0:
+ # perform automatic mixed precision quantization
+ amp_model = quantizer.amp(model)
+ _run_calibration_data(amp_model)
+ quantizer.fallback(amp_model, num=self.fallback_num)
+ print('Fallback layers:')
+ print('\n'.join(quantizer.module_filter.exclude_names))
quant_model = quantizer.quantize(model)
return quant_model
+
def _export_torchscripts(self, model, verbose, path, enc_size=None):
if enc_size:
@@ -170,17 +183,19 @@
if __name__ == '__main__':
- import sys
-
- model_path = sys.argv[1]
- output_dir = sys.argv[2]
- onnx = sys.argv[3]
- quant = sys.argv[4]
- onnx = onnx.lower()
- onnx = onnx == 'true'
- quant = quant == 'true'
- # model_path = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
- # output_dir = "../export"
- export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=onnx, quant=quant)
- export_model.export(model_path)
- # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model-name', type=str, required=True)
+ parser.add_argument('--export-dir', type=str, required=True)
+ parser.add_argument('--type', type=str, default='onnx', help='export format: "onnx" or "torch"')
+ parser.add_argument('--quantize', action='store_true', help='export quantized model')
+ parser.add_argument('--fallback-num', type=int, default=0, help='number of layers to fall back to float in AMP quantization')
+ args = parser.parse_args()
+
+ export_model = ASRModelExportParaformer(
+ cache_dir=args.export_dir,
+ onnx=args.type == 'onnx',
+ quant=args.quantize,
+ fallback_num=args.fallback_num,
+ )
+ export_model.export(args.model_name)
--
Gitblit v1.9.1