游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import logging
import torch
import torch.nn as nn
 
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export
 
class Conformer(nn.Module):
    """
    Export wrapper that turns a Conformer ASR model into ONNX files.

    The encoder and the decoder of ``model`` are wrapped independently in
    their export-friendly counterparts, and each wrapped sub-model is written
    to its own ``<model_name>.onnx`` file by :meth:`_export_onnx`.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='model',
        **kwargs,
    ):
        """
        Args:
            model: source model exposing ``encoder`` and ``decoder`` attributes.
            max_seq_len: maximum length handed to the padding-mask helper.
            feats_dim: stored on the instance as ``self.feats_dim``.
            model_name: stored on the instance as ``self.model_name``.
            **kwargs: only ``onnx`` is read here (defaults to ``False``); it is
                forwarded to the export wrappers and selects the mask helper.
        """
        super().__init__()
        onnx = kwargs.get("onnx", False)

        # The two checks are independent: a model normally has BOTH a
        # Conformer encoder and a Transformer decoder.  The previous
        # ``elif`` skipped the decoder whenever the encoder matched, so
        # ``self.decoder`` was missing and the decoder export crashed.
        if isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        if isinstance(model.decoder, TransformerDecoder):
            self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)

        self.feats_dim = feats_dim
        self.model_name = model_name

        # Pick the mask helper matching the export target.
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def _export_model(self, model, verbose, path):
        """
        Export one wrapped sub-model to ``<path>/<model.model_name>.onnx``.

        ``model`` must provide ``model_name``, ``get_dummy_inputs()``,
        ``get_input_names()``, ``get_output_names()`` and
        ``get_dynamic_axes()``.  An already existing file is left untouched
        and no export is performed.
        """
        dummy_input = model.get_dummy_inputs()
        model_path = os.path.join(path, f'{model.model_name}.onnx')
        if os.path.exists(model_path):
            # Keep the previously exported file; nothing to do.
            return
        torch.onnx.export(
            model,
            dummy_input,
            model_path,
            verbose=verbose,
            opset_version=14,
            input_names=model.get_input_names(),
            output_names=model.get_output_names(),
            dynamic_axes=model.get_dynamic_axes(),
        )

    def _export_encoder_onnx(self, verbose, path):
        """Export the wrapped encoder into directory ``path``."""
        self._export_model(self.encoder, verbose, path)

    def _export_decoder_onnx(self, verbose, path):
        """Export the wrapped decoder into directory ``path``."""
        self._export_model(self.decoder, verbose, path)

    def _export_onnx(self, verbose, path):
        """Export the encoder first, then the decoder, into ``path``."""
        self._export_encoder_onnx(verbose, path)
        self._export_decoder_onnx(verbose, path)