[Quantization] model quantization for inference
| | |
| | | # assert torch_version > 1.9 |
| | | |
| | | class ASRModelExportParaformer: |
| | | def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True): |
| | | def __init__( |
| | | self, cache_dir: Union[Path, str] = None, onnx: bool = True, quant: bool = True |
| | | ): |
| | | assert check_argument_types() |
| | | self.set_all_random_seed(0) |
| | | if cache_dir is None: |
| | |
| | | ) |
| | | print("output dir: {}".format(self.cache_dir)) |
| | | self.onnx = onnx |
| | | self.quant = quant |
| | | |
| | | |
| | | def _export( |
| | |
| | | print("output dir: {}".format(export_dir)) |
| | | |
| | | |
| | | def _torch_quantize(self, model): |
| | | from torch_quant.module import ModuleFilter |
| | | from torch_quant.observer import HistogramObserver |
| | | from torch_quant.quantizer import Backend, Quantizer |
| | | from funasr.export.models.modules.decoder_layer import DecoderLayerSANM |
| | | from funasr.export.models.modules.encoder_layer import EncoderLayerSANM |
| | | module_filter = ModuleFilter(include_classes=[EncoderLayerSANM, DecoderLayerSANM]) |
| | | module_filter.exclude_op_types = [torch.nn.Conv1d] |
| | | quantizer = Quantizer( |
| | | module_filter=module_filter, |
| | | backend=Backend.FBGEMM, |
| | | act_ob_ctr=HistogramObserver, |
| | | ) |
| | | model.eval() |
| | | calib_model = quantizer.calib(model) |
| | | # run calibration data |
| | | # using dummy inputs for a example |
| | | dummy_input = model.get_dummy_inputs() |
| | | _ = calib_model(*dummy_input) |
| | | quant_model = quantizer.quantize(model) |
| | | return quant_model |
| | | |
| | | def _export_torchscripts(self, model, verbose, path, enc_size=None): |
| | | if enc_size: |
| | | dummy_input = model.get_dummy_inputs(enc_size) |
| | |
| | | # model_script = torch.jit.script(model) |
| | | model_script = torch.jit.trace(model, dummy_input) |
| | | model_script.save(os.path.join(path, f'{model.model_name}.torchscripts')) |
| | | |
| | | if self.quant: |
| | | quant_model = self._torch_quantize(model) |
| | | model_script = torch.jit.trace(quant_model, dummy_input) |
| | | model_script.save(os.path.join(path, f'{model.model_name}_quant.torchscripts')) |
| | | |
| | | |
| | | def set_all_random_seed(self, seed: int): |
| | | random.seed(seed) |
| | |
| | | |
| | | # model_script = torch.jit.script(model) |
| | | model_script = model #torch.jit.trace(model) |
| | | model_path = os.path.join(path, f'{model.model_name}.onnx') |
| | | |
| | | torch.onnx.export( |
| | | model_script, |
| | | dummy_input, |
| | | os.path.join(path, f'{model.model_name}.onnx'), |
| | | model_path, |
| | | verbose=verbose, |
| | | opset_version=14, |
| | | input_names=model.get_input_names(), |
| | | output_names=model.get_output_names(), |
| | | dynamic_axes=model.get_dynamic_axes() |
| | | ) |
| | | |
| | | if self.quant: |
| | | from onnxruntime.quantization import QuantType, quantize_dynamic |
| | | quant_model_path = os.path.join(path, f'{model.model_name}_quant.onnx') |
| | | quantize_dynamic( |
| | | model_input=model_path, |
| | | model_output=quant_model_path, |
| | | weight_type=QuantType.QUInt8, |
| | | ) |
| | | |
| | | |
| | | if __name__ == '__main__': |
| | |
| | | model_path = sys.argv[1] |
| | | output_dir = sys.argv[2] |
| | | onnx = sys.argv[3] |
| | | quant = sys.argv[4] |
| | | onnx = onnx.lower() |
| | | onnx = onnx == 'true' |
| | | quant = quant == 'true' |
| | | # model_path = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' |
| | | # output_dir = "../export" |
| | | export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=onnx) |
| | | export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=onnx, quant=quant) |
| | | export_model.export(model_path) |
| | | # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') |
| | | # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') |
| | |
| | | self.feed_forward = model.feed_forward |
| | | self.norm1 = model.norm1 |
| | | self.norm2 = model.norm2 |
| | | self.in_size = model.in_size |
| | | self.size = model.size |
| | | |
| | | def forward(self, x, mask): |
| | |
| | | residual = x |
| | | x = self.norm1(x) |
| | | x = self.self_attn(x, mask) |
| | | if x.size(2) == residual.size(2): |
| | | if self.in_size == self.size: |
| | | x = x + residual |
| | | residual = x |
| | | x = self.norm2(x) |
| | | x = self.feed_forward(x) |
| | | if x.size(2) == residual.size(2): |
| | | x = x + residual |
| | | x = x + residual |
| | | |
| | | return x, mask |
| | | |
| | |
| | | return self.linear_out(context_layer) # (batch, time1, d_model) |
| | | |
| | | |
| | | def preprocess_for_attn(x, mask, cache, pad_fn): |
| | | x = x * mask |
| | | x = x.transpose(1, 2) |
| | | if cache is None: |
| | | x = pad_fn(x) |
| | | else: |
| | | x = torch.cat((cache[:, :, 1:], x), dim=2) |
| | | cache = x |
| | | return x, cache |
| | | |
| | | |
| | | import torch.fx |
| | | torch.fx.wrap('preprocess_for_attn') |
| | | |
| | | |
| | | class MultiHeadedAttentionSANMDecoder(nn.Module): |
| | | def __init__(self, model): |
| | | super().__init__() |
| | |
| | | self.attn = None |
| | | |
| | | def forward(self, inputs, mask, cache=None): |
| | | # b, t, d = inputs.size() |
| | | # mask = torch.reshape(mask, (b, -1, 1)) |
| | | inputs = inputs * mask |
| | | |
| | | x = inputs.transpose(1, 2) |
| | | if cache is None: |
| | | x = self.pad_fn(x) |
| | | else: |
| | | x = torch.cat((cache[:, :, 1:], x), dim=2) |
| | | cache = x |
| | | x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn) |
| | | x = self.fsmn_block(x) |
| | | x = x.transpose(1, 2) |
| | | |
| | |
| | | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) |
| | | context_layer = context_layer.view(new_context_layer_shape) |
| | | return self.linear_out(context_layer) # (batch, time1, d_model) |
| | | |
| | | |