From d14bb0f843aea0aeb254a5b3c21d42d04e28765b Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: 星期五, 14 七月 2023 17:43:40 +0800
Subject: [PATCH] add conformer export

---
 funasr/export/models/e2e_asr_conformer.py       |   85 +++++-----------
 funasr/export/export_conformer.py               |  151 ++++++++++++++++++++++++++++++
 funasr/export/models/__init__.py                |    4 
 funasr/export/models/decoder/xformer_decoder.py |   50 ++++++---
 4 files changed, 215 insertions(+), 75 deletions(-)

diff --git a/funasr/export/export_conformer.py b/funasr/export/export_conformer.py
new file mode 100644
index 0000000..4980775
--- /dev/null
+++ b/funasr/export/export_conformer.py
@@ -0,0 +1,151 @@
+import json
+from typing import Union, Dict
+from pathlib import Path
+
+import os
+import logging
+import torch
+
+from funasr.export.models import get_model
+import numpy as np
+import random
+from funasr.utils.types import str2bool, str2triple_str
+# torch_version = float(".".join(torch.__version__.split(".")[:2]))
+# assert torch_version > 1.9
+
+class ModelExport:
+    def __init__(
+        self,
+        cache_dir: Union[Path, str] = None,
+        onnx: bool = True,
+        device: str = "cpu",
+        quant: bool = True,
+        fallback_num: int = 0,
+        audio_in: str = None,
+        calib_num: int = 200,
+        model_revision: str = None,
+    ):
+        self.set_all_random_seed(0)
+
+        self.cache_dir = cache_dir
+        self.export_config = dict(
+            feats_dim=560,
+            onnx=False,
+        )
+        
+        self.onnx = onnx
+        self.device = device
+        self.quant = quant
+        self.fallback_num = fallback_num
+        self.frontend = None
+        self.audio_in = audio_in
+        self.calib_num = calib_num
+        self.model_revision = model_revision
+
+    def _export(
+        self,
+        model,
+        model_dir: str = None,
+        verbose: bool = False,
+    ):
+
+        export_dir = model_dir
+        os.makedirs(export_dir, exist_ok=True)
+
+        self.export_config["model_name"] = "model"
+        model = get_model(
+            model,
+            self.export_config,
+        )
+        model.eval()
+
+        if self.onnx:
+            self._export_onnx(model, verbose, export_dir)
+
+        print("output dir: {}".format(export_dir))
+
+    def _export_onnx(self, model, verbose, path):
+        model._export_onnx(verbose, path)
+
+    def set_all_random_seed(self, seed: int):
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.random.manual_seed(seed)
+
+    def parse_audio_in(self, audio_in):
+        
+        wav_list, name_list = [], []
+        if audio_in.endswith(".scp"):
+            f = open(audio_in, 'r')
+            lines = f.readlines()[:self.calib_num]
+            for line in lines:
+                name, path = line.strip().split()
+                name_list.append(name)
+                wav_list.append(path)
+        else:
+            wav_list = [audio_in,]
+            name_list = ["test",]
+        return wav_list, name_list
+    
+    def load_feats(self, audio_in: str = None):
+        import torchaudio
+
+        wav_list, name_list = self.parse_audio_in(audio_in)
+        feats = []
+        feats_len = []
+        for line in wav_list:
+            path = line.strip()
+            waveform, sampling_rate = torchaudio.load(path)
+            if sampling_rate != self.frontend.fs:
+                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
+                                                          new_freq=self.frontend.fs)(waveform)
+            fbank, fbank_len = self.frontend(waveform, [waveform.size(1)])
+            feats.append(fbank)
+            feats_len.append(fbank_len)
+        return feats, feats_len
+
+    def export(self,
+               mode: str = None,
+               ):
+
+        if mode.startswith('conformer'):
+            from funasr.tasks.asr import ASRTask
+            config = os.path.join(model_dir, 'config.yaml')
+            model_file = os.path.join(model_dir, 'model.pb')
+            cmvn_file = os.path.join(model_dir, 'am.mvn')
+            model, asr_train_args = ASRTask.build_model_from_file(
+                config, model_file, cmvn_file, 'cpu'
+            )
+            self.frontend = model.frontend
+            self.export_config["feats_dim"] = 560
+
+        self._export(model, self.cache_dir)
+
+if __name__ == '__main__':
+    import argparse
+    parser = argparse.ArgumentParser()
+    # parser.add_argument('--model-name', type=str, required=True)
+    parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
+    parser.add_argument('--export-dir', type=str, required=True)
+    parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
+    parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
+    parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
+    parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
+    parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
+    parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
+    parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
+    args = parser.parse_args()
+
+    export_model = ModelExport(
+        cache_dir=args.export_dir,
+        onnx=args.type == 'onnx',
+        device=args.device,
+        quant=args.quantize,
+        fallback_num=args.fallback_num,
+        audio_in=args.audio_in,
+        calib_num=args.calib_num,
+        model_revision=args.model_revision,
+    )
+    for model_name in args.model_name:
+        print("export model: {}".format(model_name))
+        export_model.export(model_name)
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index 0e3a782..6177119 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -1,6 +1,8 @@
 from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
 from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
 from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
+from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
+
 from funasr.models.e2e_vad import E2EVadModel
 from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
 from funasr.models.target_delay_transformer import TargetDelayTransformer
@@ -14,6 +16,8 @@
         return BiCifParaformer_export(model, **export_config)
     elif isinstance(model, Paraformer):
         return Paraformer_export(model, **export_config)
+    elif isinstance(model, Conformer_export):
+        return Conformer_export(model, **export_config)
     elif isinstance(model, E2EVadModel):
         return E2EVadModel_export(model, **export_config)
     elif isinstance(model, PunctuationModel):
diff --git a/funasr/export/models/decoder/xformer_decoder.py b/funasr/export/models/decoder/xformer_decoder.py
index 29837e1..15199aa 100644
--- a/funasr/export/models/decoder/xformer_decoder.py
+++ b/funasr/export/models/decoder/xformer_decoder.py
@@ -13,16 +13,24 @@
 from funasr.export.utils.torch_function import MakePadMask, subsequent_mask
 
 class XformerDecoder(nn.Module):
-    def __init__(self, model, max_seq_len=512, **kwargs):
+    def __init__(self,
+                 model,
+                 max_seq_len = 512,
+                 model_name = 'decoder',
+                 onnx: bool = True,):
         super().__init__()
         self.embed = Embedding(model.embed, max_seq_len)
         self.model = model
-        self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = subsequent_mask(max_seq_len, flip=False)
+
         if isinstance(self.model.decoders[0].self_attn, MultiHeadedAttention):
             self.num_heads = self.model.decoders[0].self_attn.h
             self.hidden_size = self.model.decoders[0].self_attn.linear_out.out_features
 
-        # replace multihead attention module into customized module.
+        # replace multi-head attention module into customized module.
         for i, d in enumerate(self.model.decoders):
             # d is DecoderLayer
             if isinstance(d.self_attn, MultiHeadedAttention):
@@ -31,17 +39,23 @@
                 d.src_attn = OnnxMultiHeadedAttention(d.src_attn)
             self.model.decoders[i] = OnnxDecoderLayer(d)
 
-        self.model_name = "xformer_decoder"
+        self.model_name = model_name
 
     def prepare_mask(self, mask):
+        mask_3d_btd = mask[:, :, None]
         if len(mask.shape) == 2:
-            mask = mask[:, None, None, :]
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
         elif len(mask.shape) == 3:
-            mask = mask[:, None, :]
-        mask = 1 - mask
-        return mask * -10000.0
+            mask_4d_bhlt = 1 - mask[:, None, :]
 
-    def forward(self, tgt, memory, cache):
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+        return mask_3d_btd, mask_4d_bhlt
+
+    def forward(self,
+                tgt,
+                memory,
+                cache):
+
         mask = subsequent_mask(tgt.size(-1)).unsqueeze(0)  # (B, T)
 
         x = self.embed(tgt)
@@ -63,33 +77,37 @@
 
     def get_dummy_inputs(self, enc_size):
         tgt = torch.LongTensor([0]).unsqueeze(0)
-        enc_out = torch.randn(1, 100, enc_size)
+        memory = torch.randn(1, 100, enc_size)
+        cache_num = len(self.model.decoders)
         cache = [
             torch.zeros((1, 1, self.model.decoders[0].size))
-            for _ in range(len(self.model.decoders))
+            for _ in range(cache_num)
         ]
-        return (tgt, enc_out, cache)
+        return (tgt, memory, cache)
 
     def is_optimizable(self):
         return True
 
     def get_input_names(self):
+        cache_num = len(self.model.decoders)
         return ["tgt", "memory"] + [
-            "cache_%d" % i for i in range(len(self.model.decoders))
+            "cache_%d" % i for i in range(cache_num)
         ]
 
     def get_output_names(self):
-        return ["y"] + ["out_cache_%d" % i for i in range(len(self.model.decoders))]
+        cache_num = len(self.model.decoders)
+        return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]
 
     def get_dynamic_axes(self):
         ret = {
             "tgt": {0: "tgt_batch", 1: "tgt_length"},
             "memory": {0: "memory_batch", 1: "memory_length"},
         }
+        cache_num = len(self.model.decoders)
         ret.update(
             {
-                "cache_%d" % d: {0: "cache_%d_batch" % d, 1: "cache_%d_length" % d}
-                for d in range(len(self.model.decoders))
+                "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
+                for d in range(cache_num)
             }
         )
         return ret
diff --git a/funasr/export/models/e2e_asr_conformer.py b/funasr/export/models/e2e_asr_conformer.py
index 69907fb..45feda5 100644
--- a/funasr/export/models/e2e_asr_conformer.py
+++ b/funasr/export/models/e2e_asr_conformer.py
@@ -1,3 +1,4 @@
+import os
 import logging
 import torch
 import torch.nn as nn
@@ -5,6 +6,7 @@
 from funasr.export.utils.torch_function import MakePadMask
 from funasr.export.utils.torch_function import sequence_mask
 from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.decoder.transformer_decoder import TransformerDecoder
 from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
 from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export
 
@@ -18,7 +20,6 @@
             model,
             max_seq_len=512,
             feats_dim=560,
-            output_size=2048,
             model_name='model',
             **kwargs,
     ):
@@ -32,71 +33,37 @@
             self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)
         
         self.feats_dim = feats_dim
-        self.output_size = output_size
         self.model_name = model_name
 
         if onnx:
             self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
         else:
             self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-        
-    def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-    ):
-        # a. To device
-        batch = {"speech": speech, "speech_lengths": speech_lengths}
-        # batch = to_device(batch, device=self.device)
-    
-        enc, enc_len = self.encoder(**batch)
-        mask = self.make_pad_mask(enc_len)[:, None, :]
 
-        # fill the decoder input
-        enc_size = self.encoder.output_size
-        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
-        cache_num = len(self.model.decoder)
-        cache = [
-            torch.zeros((1, self.decoder.size, self.decoder.self_attn.kernel_size))
-            for _ in range(cache_num)
-        ]
+    def _export_model(self, model, verbose, path):
+        dummy_input = model.get_dummy_inputs()
+        model_script = model
+        model_path = os.path.join(path, f'{model.model_name}.onnx')
+        if not os.path.exists(model_path):
+            torch.onnx.export(
+                model_script,
+                dummy_input,
+                model_path,
+                verbose=verbose,
+                opset_version=14,
+                input_names=model.get_input_names(),
+                output_names=model.get_output_names(),
+                dynamic_axes=model.get_dynamic_axes()
+            )
 
-        decoder_out, olens = self.decoder(enc, enc_len, pre_acoustic_embeds, cache)
-        decoder_out = torch.log_softmax(decoder_out, dim=-1)
-        # sample_ids = decoder_out.argmax(dim=-1)
+    def _export_encoder_onnx(self, verbose, path):
+        model_encoder = self.encoder
+        self._export_model(model_encoder, verbose, path)
 
-        return decoder_out, olens
+    def _export_decoder_onnx(self, verbose, path):
+        model_decoder = self.decoder
+        self._export_model(model_decoder, verbose, path)
 
-    def get_dummy_inputs(self):
-        speech = torch.randn(2, 30, self.feats_dim)
-        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
-        return (speech, speech_lengths)
-
-    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
-        import numpy as np
-        fbank = np.loadtxt(txt_file)
-        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
-        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
-        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
-        return (speech, speech_lengths)
-
-    def get_input_names(self):
-        return ['speech', 'speech_lengths']
-
-    def get_output_names(self):
-        return ['logits', 'token_num']
-
-    def get_dynamic_axes(self):
-        return {
-            'speech': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-            'speech_lengths': {
-                0: 'batch_size',
-            },
-            'logits': {
-                0: 'batch_size',
-                1: 'logits_length'
-            },
-        }
+    def _export_onnx(self, verbose, path):
+        self._export_encoder_onnx(verbose, path)
+        self._export_decoder_onnx(verbose, path)
\ No newline at end of file

--
Gitblit v1.9.1