paraformer onnx fp16导出方案 (#2264)
* onnx fp16模型
* paraformer-offline [fp32 fp16 onnx-gpu]
* paraformer-offline [fp32 fp16 onnx-gpu]
* Update export.py
---------
Co-authored-by: zhifu gao <zhifu.gzf@alibaba-inc.com>
| New file |
| | |
# Method 2: export a model that is loaded from a local directory.
from funasr import AutoModel

# Local checkout of the Paraformer-large model to be exported.
MODEL_DIR = "/raid/t3cv/wangch/WORK_SAPCE/ASR/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"

model = AutoModel(model=MODEL_DIR)

# type="onnx" requests the fp32 ONNX-GPU export.
# To request the fp16 ONNX-GPU export instead, use type="onnx_fp16"
# and leave the remaining options unchanged.
export_options = dict(
    type="onnx",
    quantize=False,
    opset_version=13,
    device="cuda",
)
res = model.export(**export_options)
|
| | |
| | | hidden, alphas, token_num, mask=None
|
| | | )
|
| | |
|
| | | acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
| | | acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
|
| | | if target_length is None and self.tail_threshold > 0.0:
|
| | | token_num_int = torch.max(token_num).type(torch.int32).item()
|
| | | acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
|
| | |
| | | mask = mask.transpose(-1, -2).float()
|
| | | mask = mask.squeeze(-1)
|
| | | hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
|
| | | acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
|
| | | acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold)
|
| | |
|
| | | return acoustic_embeds, token_num, alphas, cif_peak
|
| | |
|
| | |
| | | fires = fires + prefix_sum - prefix_sum_floor
|
| | |
|
| | | # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
|
| | | prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
|
| | | prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1)
|
| | | frames = prefix_sum_hidden[fire_idxs]
|
| | | shift_frames = torch.roll(frames, 1, dims=0)
|
| | |
|
| | |
| | |
|
| | | remains = fires - torch.floor(fires)
|
| | | # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
| | | remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
| | | remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs]
|
| | |
|
| | | shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
|
| | | shift_remain_frames[shift_batch_idxs] = 0
|
| | |
| | | # frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
|
| | | # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
|
| | | frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
|
| | | prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)
|
| | | prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1)
|
| | |
|
| | | frames = prefix_sum_hidden[fire_idxs]
|
| | | shift_frames = torch.roll(frames, 1, dims=0)
|
| | |
| | |
|
| | | remains = fires - torch.floor(fires)
|
| | | # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
| | | remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
|
| | | remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs]
|
| | |
|
| | | shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
|
| | | shift_remain_frames[shift_batch_idxs] = 0
|
| | |
| | | 0: "batch_size", |
| | | }, |
| | | "logits": {0: "batch_size", 1: "logits_length"}, |
| | | "token_num": {0: "batch_size"} |
| | | } |
| | | |
| | | |
| | |
| | | import os |
| | | import torch |
| | | import functools |
| | | import onnx |
| | | from onnxconverter_common import float16 |
| | | |
| | | import warnings |
| | | warnings.filterwarnings("ignore") |
| | | |
| | | |
| | | |
| | | def export( |
| | |
| | | if hasattr(m, "encoder") and hasattr(m, "decoder"): |
| | | _bladedisc_opt_for_encdec(m, path=export_dir, enable_fp16=True) |
| | | else: |
| | | print(f"export_dir: {export_dir}") |
| | | _torchscripts(m, path=export_dir, device="cuda") |
| | | print("output dir: {}".format(export_dir)) |
| | | |
| | | |
| | | elif type=='onnx_fp16': |
| | | assert ( |
| | | torch.cuda.is_available() |
| | | ), "Currently onnx_fp16 optimization for FunASR only supports GPU" |
| | | |
| | | if hasattr(m, "encoder") and hasattr(m, "decoder"): |
| | | _onnx_opt_for_encdec(m, path=export_dir, enable_fp16=True) |
| | | |
| | | return export_dir |
| | | |
| | |
| | | ): |
| | | |
| | | dummy_input = model.export_dummy_inputs() |
| | | dummy_input = (dummy_input[0].to("cuda"), dummy_input[1].to("cuda")) |
| | | |
| | | |
| | | verbose = kwargs.get("verbose", False) |
| | | |
| | |
| | | dummy_input, |
| | | model_path, |
| | | verbose=verbose, |
| | | do_constant_folding=True, |
| | | opset_version=opset_version, |
| | | input_names=model.export_input_names(), |
| | | output_names=model.export_output_names(), |
| | |
| | | |
| | | # Rescale encoder modules |
| | | fp16_scale = int(2 * absmax // 65536) |
| | | print(f"rescale encoder modules with factor={fp16_scale}") |
| | | print(f"rescale encoder modules with factor={fp16_scale}\n\n") |
| | | model.encoder.model.encoders0.register_forward_pre_hook( |
| | | functools.partial(_rescale_input_hook, scale=fp16_scale), |
| | | ) |
| | |
| | | model.decoder = _bladedisc_opt(model.decoder, tuple(decoder_inputs)) |
| | | model_script = torch.jit.trace(model, input_data) |
| | | model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript")) |
| | | |
| | | |
| | | |
def _onnx_opt_for_encdec(model, path, enable_fp16):
    """Export an encoder-decoder model to ONNX and derive an FP16 copy of it.

    The model and its dummy inputs are moved to CUDA, the model is run once
    on those inputs, and then two files are produced under ``path``:

    * ``{model.export_name}_hook.onnx``      -- FP32 graph from ``torch.onnx.export``
    * ``{model.export_name}_hook_fp16.onnx`` -- FP16 graph converted from the
      FP32 file with ``keep_io_types=True`` (graph inputs/outputs keep their
      original types).

    Either file is left untouched if it already exists, so a stale file must
    be deleted by hand to force a re-export.

    Args:
        model: Export-ready module providing ``export_dummy_inputs()``,
            ``export_input_names()``, ``export_output_names()``,
            ``export_dynamic_axes()`` and an ``export_name`` attribute.
        path: Directory the ONNX files are written to.
        enable_fp16: When true, ``_rescale_encoder_model`` is applied before
            the export to prevent FP16 overflow. Note that the FP16 file is
            generated regardless of this flag.

    Returns:
        None. The results are the files written to ``path``.
    """
    # Get input data
    # TODO: better to use real data
    input_data = model.export_dummy_inputs()
    if isinstance(input_data, torch.Tensor):
        input_data = input_data.cuda()
        # A bare tensor must be wrapped before unpacking: ``model(*tensor)``
        # would splat its first dimension into separate positional arguments.
        forward_args = (input_data,)
    else:
        input_data = tuple(i.cuda() for i in input_data)
        forward_args = input_data

    # Run the model once on GPU before exporting. The decoder-input capture
    # hook that used to wrap this call was dropped: the inputs it collected
    # were never used by the ONNX path.
    model = model.cuda()
    model(*forward_args)

    # Prevent FP16 overflow
    if enable_fp16:
        _rescale_encoder_model(model, input_data)

    fp32_model_path = os.path.join(path, f"{model.export_name}_hook.onnx")
    print("*" * 50)
    print(f"[_onnx_opt_for_encdec(fp32)]: {fp32_model_path}\n\n")
    if not os.path.exists(fp32_model_path):
        torch.onnx.export(
            model,
            input_data,
            fp32_model_path,
            verbose=False,
            do_constant_folding=True,
            opset_version=13,
            input_names=model.export_input_names(),
            output_names=model.export_output_names(),
            dynamic_axes=model.export_dynamic_axes(),
        )

    # fp32 to fp16
    fp16_model_path = os.path.join(path, f"{model.export_name}_hook_fp16.onnx")
    print("*" * 50)
    print(f"[_onnx_opt_for_encdec(fp16)]: {fp16_model_path}\n\n")
    if os.path.exists(fp32_model_path) and not os.path.exists(fp16_model_path):
        fp32_onnx_model = onnx.load(fp32_model_path)
        fp16_onnx_model = float16.convert_float_to_float16(
            fp32_onnx_model, keep_io_types=True
        )
        onnx.save(fp16_onnx_model, fp16_model_path)