import os
|
import logging
|
import torch
|
import torch.nn as nn
|
|
from funasr.export.utils.torch_function import MakePadMask
|
from funasr.export.utils.torch_function import sequence_mask
|
from funasr.models.encoder.conformer_encoder import ConformerEncoder
|
from funasr.models.decoder.transformer_decoder import TransformerDecoder
|
from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
|
from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export
|
|
class Conformer(nn.Module):
    """Wrap a Conformer model's encoder and decoder so they can be exported to ONNX.

    The constructor replaces the model's ``encoder`` / ``decoder`` with their
    export-ready counterparts; the ``_export_*`` helpers then write one
    ``<model_name>.onnx`` file per sub-model via ``torch.onnx.export``.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='model',
        **kwargs,
    ):
        """Build the export wrappers for ``model``.

        Args:
            model: Object exposing an ``encoder`` attribute and, optionally, a
                ``decoder`` attribute.
            max_seq_len: Forwarded to the padding-mask helper.
            feats_dim: Stored on the instance as ``self.feats_dim``.
            model_name: Stored on the instance as ``self.model_name``.
            **kwargs: Only ``onnx`` is read (default ``False``). It is passed
                to the export wrappers and selects ``MakePadMask`` (truthy)
                or ``sequence_mask`` (falsy) for ``self.make_pad_mask``.

        Note:
            ``self.encoder`` is only set when ``model.encoder`` is a
            ``ConformerEncoder`` and ``self.decoder`` only when
            ``model.decoder`` is a ``TransformerDecoder``; other types are
            left unwrapped.
        """
        super().__init__()
        onnx = kwargs.get("onnx", False)

        # Encoder and decoder are wrapped independently. They used to share an
        # if/elif chain, so a matching encoder prevented the decoder from ever
        # being wrapped and ``_export_onnx`` failed on the missing attribute.
        if isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        # ``getattr`` keeps models without a ``decoder`` attribute constructible,
        # as they were when the encoder branch short-circuited the decoder check.
        decoder = getattr(model, "decoder", None)
        if isinstance(decoder, TransformerDecoder):
            self.decoder = TransformerDecoder_export(decoder, onnx=onnx)

        self.feats_dim = feats_dim
        self.model_name = model_name

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def _export_model(self, model, verbose, path):
        """Export ``model`` to ``<path>/<model.model_name>.onnx`` unless it already exists.

        Args:
            model: Export-ready sub-model. It must provide ``model_name``,
                ``get_dummy_inputs()``, ``get_input_names()``,
                ``get_output_names()`` and ``get_dynamic_axes()``.
            verbose: Forwarded to ``torch.onnx.export``.
            path: Directory in which the ``.onnx`` file is written.
        """
        model_path = os.path.join(path, f'{model.model_name}.onnx')
        # An existing file is treated as an already-completed export.
        if os.path.exists(model_path):
            return

        torch.onnx.export(
            model,
            model.get_dummy_inputs(),
            model_path,
            verbose=verbose,
            opset_version=14,
            input_names=model.get_input_names(),
            output_names=model.get_output_names(),
            dynamic_axes=model.get_dynamic_axes(),
        )

    def _export_encoder_onnx(self, verbose, path):
        """Export the wrapped encoder (``self.encoder``) to ONNX under ``path``."""
        self._export_model(self.encoder, verbose, path)

    def _export_decoder_onnx(self, verbose, path):
        """Export the wrapped decoder (``self.decoder``) to ONNX under ``path``."""
        self._export_model(self.decoder, verbose, path)

    def _export_onnx(self, verbose, path):
        """Export the encoder, then the decoder, as separate ONNX files under ``path``."""
        self._export_encoder_onnx(verbose, path)
        self._export_decoder_onnx(verbose, path)
|