From 8a620a5a36df782e1f9e8cc56064d5dc6a1330b5 Mon Sep 17 00:00:00 2001
From: wanchen.swc <wanchen.swc@alibaba-inc.com>
Date: 星期三, 15 三月 2023 15:31:31 +0800
Subject: [PATCH] [Quantization] automatic mixed precision quantization

---
 funasr/export/export_model.py |   57 ++++++++++++++++++++++++++++++++++++---------------------
 1 files changed, 36 insertions(+), 21 deletions(-)

diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index 7370c3c..beb1efe 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -16,7 +16,11 @@
 
 class ASRModelExportParaformer:
     def __init__(
-        self, cache_dir: Union[Path, str] = None, onnx: bool = True, quant: bool = True
+        self,
+        cache_dir: Union[Path, str] = None,
+        onnx: bool = True,
+        quant: bool = True,
+        fallback_num: int = 0,
     ):
         assert check_argument_types()
         self.set_all_random_seed(0)
@@ -31,6 +35,7 @@
         print("output dir: {}".format(self.cache_dir))
         self.onnx = onnx
         self.quant = quant
+        self.fallback_num = fallback_num
         
 
     def _export(
@@ -60,8 +65,12 @@
 
 
     def _torch_quantize(self, model):
+        def _run_calibration_data(m):
+            # using dummy inputs for a example
+            dummy_input = model.get_dummy_inputs()
+            m(*dummy_input)
+
         from torch_quant.module import ModuleFilter
-        from torch_quant.observer import HistogramObserver
         from torch_quant.quantizer import Backend, Quantizer
         from funasr.export.models.modules.decoder_layer import DecoderLayerSANM
         from funasr.export.models.modules.encoder_layer import EncoderLayerSANM
@@ -70,16 +79,20 @@
         quantizer = Quantizer(
             module_filter=module_filter,
             backend=Backend.FBGEMM,
-            act_ob_ctr=HistogramObserver,
         )
         model.eval()
         calib_model = quantizer.calib(model)
-        # run calibration data
-        # using dummy inputs for a example
-        dummy_input = model.get_dummy_inputs()
-        _ = calib_model(*dummy_input)
+        _run_calibration_data(calib_model)
+        if self.fallback_num > 0:
+            # perform automatic mixed precision quantization
+            amp_model = quantizer.amp(model)
+            _run_calibration_data(amp_model)
+            quantizer.fallback(amp_model, num=self.fallback_num)
+            print('Fallback layers:')
+            print('\n'.join(quantizer.module_filter.exclude_names))
         quant_model = quantizer.quantize(model)
         return quant_model
+
 
     def _export_torchscripts(self, model, verbose, path, enc_size=None):
         if enc_size:
@@ -170,17 +183,19 @@
 
 
 if __name__ == '__main__':
-    import sys
-    
-    model_path = sys.argv[1]
-    output_dir = sys.argv[2]
-    onnx = sys.argv[3]
-    quant = sys.argv[4]
-    onnx = onnx.lower()
-    onnx = onnx == 'true'
-    quant = quant == 'true'
-    # model_path = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
-    # output_dir = "../export"
-    export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=onnx, quant=quant)
-    export_model.export(model_path)
-    # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model-name', type=str, required=True)
+    parser.add_argument('--export-dir', type=str, required=True)
+    parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
+    parser.add_argument('--quantize', action='store_true', help='export quantized model')
+    parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
+    args = parser.parse_args()
+
+    export_model = ASRModelExportParaformer(
+        cache_dir=args.export_dir,
+        onnx=args.type == 'onnx',
+        quant=args.quantize,
+        fallback_num=args.fallback_num,
+    )
+    export_model.export(args.model_name)

--
Gitblit v1.9.1