From c542eacb0aadcbc49c63db40429fca4e08f807a4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 21 Jul 2023 10:27:35 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR; add Conformer ONNX export and English Paraformer recipe
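
This patch adds ONNX export support for Conformer ASR models
(funasr/export/export_conformer.py, plus encoder/decoder and
language-model export wrappers) and a ModelScope recipe for the English
model damo/speech_paraformer_asr-en-16k-vocab4199-pytorch (demo,
finetune, and batch-inference scripts).

Example export invocation (a sketch: it assumes the model files
config.yaml, model.pb, and am.mvn are already present under
--export-dir, as ModelExport.export expects):

    python funasr/export/export_conformer.py \
        --model-name conformer \
        --export-dir ./export \
        --type onnx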

---
 funasr/export/models/language_models/transformer.py                                      |  110 ++++
 funasr/modules/vgg2l.py                                                                  |   92 ++++
 egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/infer.sh    |  103 ++++
 funasr/export/models/language_models/embed.py                                            |  403 +++++++++++++++++
 funasr/export/models/language_models/subsampling.py                                      |  185 ++++++++
 egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/README.md   |    1 
 setup.py                                                                                 |    2 
 egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/utils       |    1 
 funasr/export/export_conformer.py                                                        |  151 ++++++
 funasr/export/models/language_models/__init__.py                                         |    0 
 funasr/export/models/__init__.py                                                         |    4 
 funasr/export/models/decoder/xformer_decoder.py                                          |  121 +++++
 egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/finetune.py |   36 +
 funasr/export/models/e2e_asr_conformer.py                                                |   69 +++
 egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/infer.py    |    1 
 egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/demo.py     |   11 
 funasr/export/models/language_models/seq_rnn.py                                          |   84 +++
 17 files changed, 1373 insertions(+), 1 deletion(-)

diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/README.md b/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/README.md
new file mode 120000
index 0000000..bb55ab5
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/README.md
@@ -0,0 +1 @@
+../../TEMPLATE/README.md
\ No newline at end of file
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/demo.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/demo.py
new file mode 100644
index 0000000..ea486c8
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/demo.py
@@ -0,0 +1,11 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipeline = pipeline(
+    task=Tasks.auto_speech_recognition,
+    model='damo/speech_paraformer_asr-en-16k-vocab4199-pytorch',
+    model_revision="v1.0.1",
+)
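+# audio_in may be a URL (as below) or a local wav file path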
+audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav'
+rec_result = inference_pipeline(audio_in=audio_in)
+print(rec_result)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/finetune.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/finetune.py
new file mode 100644
index 0000000..a945b31
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/finetune.py
@@ -0,0 +1,36 @@
+import os
+
+from modelscope.metainfo import Trainers
+from modelscope.trainers import build_trainer
+
+from funasr.datasets.ms_dataset import MsDataset
+from funasr.utils.modelscope_param import modelscope_args
+
+
+def modelscope_finetune(params):
+    if not os.path.exists(params.output_dir):
+        os.makedirs(params.output_dir, exist_ok=True)
+    # dataset split ["train", "validation"]
+    ds_dict = MsDataset.load(params.data_path)
+    kwargs = dict(
+        model=params.model,
+        data_dir=ds_dict,
+        dataset_type=params.dataset_type,
+        work_dir=params.output_dir,
+        batch_bins=params.batch_bins,
+        max_epoch=params.max_epoch,
+        lr=params.lr)
+    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
+    trainer.train()
+
+
+if __name__ == '__main__':
+    params = modelscope_args(model="damo/speech_paraformer_asr-en-16k-vocab4199-pytorch", data_path="./data")
+    params.output_dir = "./checkpoint"              # m妯″瀷淇濆瓨璺緞
+    params.data_path = "./example_data/"            # 鏁版嵁璺緞
+    params.dataset_type = "small"                   # 灏忔暟鎹噺璁剧疆small锛岃嫢鏁版嵁閲忓ぇ浜�1000灏忔椂锛岃浣跨敤large
+    params.batch_bins = 2000                       # batch size锛屽鏋渄ataset_type="small"锛宐atch_bins鍗曚綅涓篺bank鐗瑰緛甯ф暟锛屽鏋渄ataset_type="large"锛宐atch_bins鍗曚綅涓烘绉掞紝
+    params.max_epoch = 50                           # 鏈�澶ц缁冭疆鏁�
+    params.lr = 0.00005                             # 璁剧疆瀛︿範鐜�
+    
+    modelscope_finetune(params)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/infer.py
new file mode 120000
index 0000000..128fc31
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/infer.py
@@ -0,0 +1 @@
+../../TEMPLATE/infer.py
\ No newline at end of file
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/infer.sh b/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/infer.sh
new file mode 100644
index 0000000..b552ae8
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/infer.sh
@@ -0,0 +1,103 @@
+#!/usr/bin/env bash
+
+set -e
+set -u
+set -o pipefail
+
+stage=1
+stop_stage=2
+model="damo/speech_paraformer_asr-en-16k-vocab4199-pytorch"
+data_dir="./data/test"
+output_dir="./results"
+batch_size=64
+gpu_inference=true    # whether to perform GPU decoding
+gpuid_list="0,1"    # GPUs to use, e.g., gpuid_list="0,1"
+njob=64    # number of parallel jobs for CPU decoding (used when gpu_inference=false)
+checkpoint_dir=
+checkpoint_name="valid.cer_ctc.ave.pb"
+
+. utils/parse_options.sh || exit 1;
+
+if [ "${gpu_inference}" == "true" ]; then
+    nj=$(echo $gpuid_list | awk -F "," '{print NF}')
+else
+    nj=$njob
+    batch_size=1
+    gpuid_list=""
+    for JOB in $(seq ${nj}); do
+        gpuid_list=$gpuid_list"-1,"
+    done
+fi
+
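+# split wav.scp into ${nj} shards so each decoding job works on its own subset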
+mkdir -p $output_dir/split
+split_scps=""
+for JOB in $(seq ${nj}); do
+    split_scps="$split_scps $output_dir/split/wav.$JOB.scp"
+done
+perl utils/split_scp.pl ${data_dir}/wav.scp ${split_scps}
+
+if [ -n "${checkpoint_dir}" ]; then
+  python utils/prepare_checkpoint.py ${model} ${checkpoint_dir} ${checkpoint_name}
+  model=${checkpoint_dir}/${model}
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ];then
+    echo "Decoding ..."
+    gpuid_list_array=(${gpuid_list//,/ })
+    for JOB in $(seq ${nj}); do
+        {
+        id=$((JOB-1))
+        gpuid=${gpuid_list_array[$id]}
+        mkdir -p ${output_dir}/output.$JOB
+        python infer.py \
+            --model ${model} \
+            --audio_in ${output_dir}/split/wav.$JOB.scp \
+            --output_dir ${output_dir}/output.$JOB \
+            --batch_size ${batch_size} \
+            --gpuid ${gpuid}
+        }&
+    done
+    wait
+
+    mkdir -p ${output_dir}/1best_recog
+    for f in token score text; do
+        if [ -f "${output_dir}/output.1/1best_recog/${f}" ]; then
+          for i in $(seq "${nj}"); do
+              cat "${output_dir}/output.${i}/1best_recog/${f}"
+          done | sort -k1 >"${output_dir}/1best_recog/${f}"
+        fi
+    done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
+    echo "Computing WER ..."
+    cp ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
+    cp ${data_dir}/text ${output_dir}/1best_recog/text.ref
+    python utils/compute_wer.py ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer
+    tail -n 3 ${output_dir}/1best_recog/text.cer
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ];then
+    echo "SpeechIO TIOBE textnorm"
+    echo "$0 --> Normalizing REF text ..."
+    ./utils/textnorm_zh.py \
+        --has_key --to_upper \
+        ${data_dir}/text \
+        ${output_dir}/1best_recog/ref.txt
+
+    echo "$0 --> Normalizing HYP text ..."
+    ./utils/textnorm_zh.py \
+        --has_key --to_upper \
+        ${output_dir}/1best_recog/text.proc \
+        ${output_dir}/1best_recog/rec.txt
+    grep -v $'\t$' ${output_dir}/1best_recog/rec.txt > ${output_dir}/1best_recog/rec_non_empty.txt
+
+    echo "$0 --> computing WER/CER and alignment ..."
+    ./utils/error_rate_zh \
+        --tokenizer char \
+        --ref ${output_dir}/1best_recog/ref.txt \
+        --hyp ${output_dir}/1best_recog/rec_non_empty.txt \
+        ${output_dir}/1best_recog/DETAILS.txt | tee ${output_dir}/1best_recog/RESULTS.txt
+    rm -rf ${output_dir}/1best_recog/rec.txt ${output_dir}/1best_recog/rec_non_empty.txt
+fi
+
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/utils b/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/utils
new file mode 120000
index 0000000..2ac163f
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr-en-16k-vocab4199-pytorch/utils
@@ -0,0 +1 @@
+../../../../egs/aishell/transformer/utils
\ No newline at end of file
diff --git a/funasr/export/export_conformer.py b/funasr/export/export_conformer.py
new file mode 100644
index 0000000..4980775
--- /dev/null
+++ b/funasr/export/export_conformer.py
@@ -0,0 +1,151 @@
+import json
+from typing import Union, Dict
+from pathlib import Path
+
+import os
+import logging
+import torch
+
+from funasr.export.models import get_model
+import numpy as np
+import random
+from funasr.utils.types import str2bool, str2triple_str
+# torch_version = float(".".join(torch.__version__.split(".")[:2]))
+# assert torch_version > 1.9
+
+class ModelExport:
+    def __init__(
+        self,
+        cache_dir: Union[Path, str] = None,
+        onnx: bool = True,
+        device: str = "cpu",
+        quant: bool = True,
+        fallback_num: int = 0,
+        audio_in: str = None,
+        calib_num: int = 200,
+        model_revision: str = None,
+    ):
+        self.set_all_random_seed(0)
+
+        self.cache_dir = cache_dir
+        self.export_config = dict(
+            feats_dim=560,
+            onnx=False,
+        )
+        
+        self.onnx = onnx
+        self.device = device
+        self.quant = quant
+        self.fallback_num = fallback_num
+        self.frontend = None
+        self.audio_in = audio_in
+        self.calib_num = calib_num
+        self.model_revision = model_revision
+
+    def _export(
+        self,
+        model,
+        model_dir: str = None,
+        verbose: bool = False,
+    ):
+
+        export_dir = model_dir
+        os.makedirs(export_dir, exist_ok=True)
+
+        self.export_config["model_name"] = "model"
+        model = get_model(
+            model,
+            self.export_config,
+        )
+        model.eval()
+
+        if self.onnx:
+            self._export_onnx(model, verbose, export_dir)
+
+        print("output dir: {}".format(export_dir))
+
+    def _export_onnx(self, model, verbose, path):
+        model._export_onnx(verbose, path)
+
+    def set_all_random_seed(self, seed: int):
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.random.manual_seed(seed)
+
+    def parse_audio_in(self, audio_in):
+        
+        wav_list, name_list = [], []
+        if audio_in.endswith(".scp"):
+            with open(audio_in, 'r') as f:
+                lines = f.readlines()[:self.calib_num]
+            for line in lines:
+                name, path = line.strip().split()
+                name_list.append(name)
+                wav_list.append(path)
+        else:
+            wav_list = [audio_in,]
+            name_list = ["test",]
+        return wav_list, name_list
+    
+    def load_feats(self, audio_in: str = None):
+        import torchaudio
+
+        wav_list, name_list = self.parse_audio_in(audio_in)
+        feats = []
+        feats_len = []
+        for line in wav_list:
+            path = line.strip()
+            waveform, sampling_rate = torchaudio.load(path)
+            if sampling_rate != self.frontend.fs:
+                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
+                                                          new_freq=self.frontend.fs)(waveform)
+            fbank, fbank_len = self.frontend(waveform, [waveform.size(1)])
+            feats.append(fbank)
+            feats_len.append(fbank_len)
+        return feats, feats_len
+
+    def export(self,
+               mode: str = None,
+               ):
+
+        if mode.startswith('conformer'):
+            from funasr.tasks.asr import ASRTask
+            # assumption: config.yaml / model.pb / am.mvn live under cache_dir
+            model_dir = self.cache_dir
+            config = os.path.join(model_dir, 'config.yaml')
+            model_file = os.path.join(model_dir, 'model.pb')
+            cmvn_file = os.path.join(model_dir, 'am.mvn')
+            model, asr_train_args = ASRTask.build_model_from_file(
+                config, model_file, cmvn_file, 'cpu'
+            )
+            self.frontend = model.frontend
+            self.export_config["feats_dim"] = 560
+
+        self._export(model, self.cache_dir)
+
+if __name__ == '__main__':
+    import argparse
+    parser = argparse.ArgumentParser()
+    # parser.add_argument('--model-name', type=str, required=True)
+    parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
+    parser.add_argument('--export-dir', type=str, required=True)
+    parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
+    parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
+    parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
+    parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
+    parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
+    parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
+    parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
+    args = parser.parse_args()
+
+    export_model = ModelExport(
+        cache_dir=args.export_dir,
+        onnx=args.type == 'onnx',
+        device=args.device,
+        quant=args.quantize,
+        fallback_num=args.fallback_num,
+        audio_in=args.audio_in,
+        calib_num=args.calib_num,
+        model_revision=args.model_revision,
+    )
+    for model_name in args.model_name:
+        print("export model: {}".format(model_name))
+        export_model.export(model_name)
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index 0e3a782..6177119 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -1,6 +1,8 @@
 from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
 from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
 from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
+from funasr.models.e2e_asr import ASRModel
+from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
+
 from funasr.models.e2e_vad import E2EVadModel
 from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
 from funasr.models.target_delay_transformer import TargetDelayTransformer
@@ -14,6 +16,8 @@
         return BiCifParaformer_export(model, **export_config)
     elif isinstance(model, Paraformer):
         return Paraformer_export(model, **export_config)
+    elif isinstance(model, ASRModel):
+        # a plain (Conformer) ASRModel maps to the Conformer export wrapper
+        return Conformer_export(model, **export_config)
     elif isinstance(model, E2EVadModel):
         return E2EVadModel_export(model, **export_config)
     elif isinstance(model, PunctuationModel):
diff --git a/funasr/export/models/decoder/xformer_decoder.py b/funasr/export/models/decoder/xformer_decoder.py
new file mode 100644
index 0000000..15199aa
--- /dev/null
+++ b/funasr/export/models/decoder/xformer_decoder.py
@@ -0,0 +1,121 @@
+import os
+
+import torch
+import torch.nn as nn
+
+from funasr.modules.attention import MultiHeadedAttention
+
+from funasr.export.models.modules.decoder_layer import DecoderLayer as OnnxDecoderLayer
+from funasr.export.models.language_models.embed import Embedding
+from funasr.export.models.modules.multihead_att import \
+    OnnxMultiHeadedAttention
+
+from funasr.export.utils.torch_function import MakePadMask, sequence_mask, subsequent_mask
+
+class XformerDecoder(nn.Module):
+    def __init__(self,
+                 model,
+                 max_seq_len = 512,
+                 model_name = 'decoder',
+                 onnx: bool = True,):
+        super().__init__()
+        self.embed = Embedding(model.embed, max_seq_len)
+        self.model = model
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        if isinstance(self.model.decoders[0].self_attn, MultiHeadedAttention):
+            self.num_heads = self.model.decoders[0].self_attn.h
+            self.hidden_size = self.model.decoders[0].self_attn.linear_out.out_features
+
+        # replace multi-head attention module into customized module.
+        for i, d in enumerate(self.model.decoders):
+            # d is DecoderLayer
+            if isinstance(d.self_attn, MultiHeadedAttention):
+                d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
+            if isinstance(d.src_attn, MultiHeadedAttention):
+                d.src_attn = OnnxMultiHeadedAttention(d.src_attn)
+            self.model.decoders[i] = OnnxDecoderLayer(d)
+
+        self.model_name = model_name
+
+    def prepare_mask(self, mask):
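+        # mask_3d_btd masks padded time steps on the feature axis (B, T, 1);
+        # mask_4d_bhlt is an additive attention bias: 0 keeps a position,
+        # -10000 effectively masks it out after the softmax.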
+        mask_3d_btd = mask[:, :, None]
+        if len(mask.shape) == 2:
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask_4d_bhlt = 1 - mask[:, None, :]
+
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+        return mask_3d_btd, mask_4d_bhlt
+
+    def forward(self,
+                tgt,
+                memory,
+                cache):
+
+        mask = subsequent_mask(tgt.size(-1)).unsqueeze(0)  # (1, L, L) causal mask
+
+        x = self.embed(tgt)
+        mask = self.prepare_mask(mask)
+        new_cache = []
+        for c, decoder in zip(cache, self.model.decoders):
+            x, mask = decoder(x, mask, memory, None, c)
+            new_cache.append(x)
+            x = x[:, 1:, :]
+
+        if self.model.normalize_before:
+            y = self.model.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+
+        if self.model.output_layer is not None:
+            y = torch.log_softmax(self.model.output_layer(y), dim=-1)
+        return y, new_cache
+
+    def get_dummy_inputs(self, enc_size):
+        tgt = torch.LongTensor([0]).unsqueeze(0)
+        memory = torch.randn(1, 100, enc_size)
+        cache_num = len(self.model.decoders)
+        cache = [
+            torch.zeros((1, 1, self.model.decoders[0].size))
+            for _ in range(cache_num)
+        ]
+        return (tgt, memory, cache)
+
+    def is_optimizable(self):
+        return True
+
+    def get_input_names(self):
+        cache_num = len(self.model.decoders)
+        return ["tgt", "memory"] + [
+            "cache_%d" % i for i in range(cache_num)
+        ]
+
+    def get_output_names(self):
+        cache_num = len(self.model.decoders)
+        return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]
+
+    def get_dynamic_axes(self):
+        ret = {
+            "tgt": {0: "tgt_batch", 1: "tgt_length"},
+            "memory": {0: "memory_batch", 1: "memory_length"},
+        }
+        cache_num = len(self.model.decoders)
+        ret.update(
+            {
+                "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
+                for d in range(cache_num)
+            }
+        )
+        return ret
+
+    def get_model_config(self, path):
+        return {
+            "dec_type": "XformerDecoder",
+            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
+            "n_layers": len(self.model.decoders),
+            "odim": self.model.decoders[0].size,
+        }
diff --git a/funasr/export/models/e2e_asr_conformer.py b/funasr/export/models/e2e_asr_conformer.py
new file mode 100644
index 0000000..45feda5
--- /dev/null
+++ b/funasr/export/models/e2e_asr_conformer.py
@@ -0,0 +1,69 @@
+import os
+import logging
+import torch
+import torch.nn as nn
+
+from funasr.export.utils.torch_function import MakePadMask
+from funasr.export.utils.torch_function import sequence_mask
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
+from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export
+
+class Conformer(nn.Module):
+    """
+    export conformer into onnx format
+    """
+
+    def __init__(
+            self,
+            model,
+            max_seq_len=512,
+            feats_dim=560,
+            model_name='model',
+            **kwargs,
+    ):
+        super().__init__()
+        onnx = False
+        if "onnx" in kwargs:
+            onnx = kwargs["onnx"]
+        # wrap encoder and decoder independently; both are needed for export
+        if isinstance(model.encoder, ConformerEncoder):
+            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
+        if isinstance(model.decoder, TransformerDecoder):
+            self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)
+        
+        self.feats_dim = feats_dim
+        self.model_name = model_name
+
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+    def _export_model(self, model, verbose, path):
+        dummy_input = model.get_dummy_inputs()
+        model_script = model
+        model_path = os.path.join(path, f'{model.model_name}.onnx')
+        if not os.path.exists(model_path):
+            torch.onnx.export(
+                model_script,
+                dummy_input,
+                model_path,
+                verbose=verbose,
+                opset_version=14,
+                input_names=model.get_input_names(),
+                output_names=model.get_output_names(),
+                dynamic_axes=model.get_dynamic_axes()
+            )
+
+    def _export_encoder_onnx(self, verbose, path):
+        model_encoder = self.encoder
+        self._export_model(model_encoder, verbose, path)
+
+    def _export_decoder_onnx(self, verbose, path):
+        model_decoder = self.decoder
+        self._export_model(model_decoder, verbose, path)
+
+    def _export_onnx(self, verbose, path):
+        self._export_encoder_onnx(verbose, path)
+        self._export_decoder_onnx(verbose, path)
\ No newline at end of file
diff --git a/funasr/export/models/language_models/__init__.py b/funasr/export/models/language_models/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/export/models/language_models/__init__.py
diff --git a/funasr/export/models/language_models/embed.py b/funasr/export/models/language_models/embed.py
new file mode 100644
index 0000000..57748f2
--- /dev/null
+++ b/funasr/export/models/language_models/embed.py
@@ -0,0 +1,403 @@
+"""Positional Encoding Module."""
+
+import math
+
+import torch
+import torch.nn as nn
+from funasr.modules.embedding import (
+    LegacyRelPositionalEncoding, PositionalEncoding, RelPositionalEncoding,
+    ScaledPositionalEncoding, StreamPositionalEncoding)
+from funasr.modules.subsampling import (
+    Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6,
+    Conv2dSubsampling8)
+from funasr.modules.subsampling_without_posenc import \
+    Conv2dSubsamplingWOPosEnc
+
+from funasr.export.models.language_models.subsampling import (
+    OnnxConv2dSubsampling, OnnxConv2dSubsampling2, OnnxConv2dSubsampling6,
+    OnnxConv2dSubsampling8)
+
+
+def get_pos_emb(pos_emb, max_seq_len=512, use_cache=True):
+    if isinstance(pos_emb, LegacyRelPositionalEncoding):
+        return OnnxLegacyRelPositionalEncoding(pos_emb, max_seq_len, use_cache)
+    elif isinstance(pos_emb, ScaledPositionalEncoding):
+        return OnnxScaledPositionalEncoding(pos_emb, max_seq_len, use_cache)
+    elif isinstance(pos_emb, RelPositionalEncoding):
+        return OnnxRelPositionalEncoding(pos_emb, max_seq_len, use_cache)
+    elif isinstance(pos_emb, PositionalEncoding):
+        return OnnxPositionalEncoding(pos_emb, max_seq_len, use_cache)
+    elif isinstance(pos_emb, StreamPositionalEncoding):
+        return OnnxStreamPositionalEncoding(pos_emb, max_seq_len, use_cache)
+    elif (isinstance(pos_emb, nn.Sequential) and len(pos_emb) == 0) or (
+        isinstance(pos_emb, Conv2dSubsamplingWOPosEnc)
+    ):
+        return pos_emb
+    else:
+        raise ValueError("Embedding model is not supported.")
+
+
+class Embedding(nn.Module):
+    def __init__(self, model, max_seq_len=512, use_cache=True):
+        super().__init__()
+        self.model = model
+        if not isinstance(model, nn.Embedding):
+            if isinstance(model, Conv2dSubsampling):
+                self.model = OnnxConv2dSubsampling(model)
+                self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
+            elif isinstance(model, Conv2dSubsampling2):
+                self.model = OnnxConv2dSubsampling2(model)
+                self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
+            elif isinstance(model, Conv2dSubsampling6):
+                self.model = OnnxConv2dSubsampling6(model)
+                self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
+            elif isinstance(model, Conv2dSubsampling8):
+                self.model = OnnxConv2dSubsampling8(model)
+                self.model.out[-1] = get_pos_emb(model.out[-1], max_seq_len)
+            else:
+                self.model[-1] = get_pos_emb(model[-1], max_seq_len)
+
+    def forward(self, x, mask=None):
+        if mask is None:
+            return self.model(x)
+        else:
+            return self.model(x, mask)
+
+
+def _pre_hook(
+    state_dict,
+    prefix,
+    local_metadata,
+    strict,
+    missing_keys,
+    unexpected_keys,
+    error_msgs,
+):
+    """Perform pre-hook in load_state_dict for backward compatibility.
+
+    Note:
+        We saved self.pe until v.0.5.2 but we have omitted it later.
+        Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
+
+    """
+    k = prefix + "pe"
+    if k in state_dict:
+        state_dict.pop(k)
+
+
+class OnnxPositionalEncoding(torch.nn.Module):
+    """Positional encoding.
+
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_seq_len (int): Maximum input length.
+        reverse (bool): Whether to reverse the input position. Only for
+        the class LegacyRelPositionalEncoding. We remove it in the current
+        class RelPositionalEncoding.
+    """
+
+    def __init__(self, model, max_seq_len=512, reverse=False, use_cache=True):
+        """Construct an PositionalEncoding object."""
+        super(OnnxPositionalEncoding, self).__init__()
+        self.d_model = model.d_model
+        self.reverse = reverse
+        self.max_seq_len = max_seq_len
+        self.xscale = math.sqrt(self.d_model)
+        self._register_load_state_dict_pre_hook(_pre_hook)
+        self.pe = model.pe
+        self.use_cache = use_cache
+        self.model = model
+        if self.use_cache:
+            self.extend_pe()
+        else:
+            self.div_term = torch.exp(
+                torch.arange(0, self.d_model, 2, dtype=torch.float32)
+                * -(math.log(10000.0) / self.d_model)
+            )
+
+    def extend_pe(self):
+        """Reset the positional encodings."""
+        pe_length = len(self.pe[0])
+        if self.max_seq_len < pe_length:
+            self.pe = self.pe[:, : self.max_seq_len]
+        else:
+            self.model.extend_pe(torch.tensor(0.0).expand(1, self.max_seq_len))
+            self.pe = self.model.pe
+
+    def _add_pe(self, x):
+        """Computes positional encoding"""
+        if self.reverse:
+            position = torch.arange(
+                x.size(1) - 1, -1, -1.0, dtype=torch.float32
+            ).unsqueeze(1)
+        else:
+            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+
+        x = x * self.xscale
+        x[:, :, 0::2] += torch.sin(position * self.div_term)
+        x[:, :, 1::2] += torch.cos(position * self.div_term)
+        return x
+
+    def forward(self, x: torch.Tensor):
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+        """
+        if self.use_cache:
+            x = x * self.xscale + self.pe[:, : x.size(1)]
+        else:
+            x = self._add_pe(x)
+        return x
+
+
+class OnnxScaledPositionalEncoding(OnnxPositionalEncoding):
+    """Scaled positional encoding module.
+
+    See Sec. 3.2  https://arxiv.org/abs/1809.08895
+
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_seq_len (int): Maximum input length.
+
+    """
+
+    def __init__(self, model, max_seq_len=512, use_cache=True):
+        """Initialize class."""
+        super().__init__(model, max_seq_len, use_cache=use_cache)
+        self.alpha = torch.nn.Parameter(torch.tensor(1.0))
+
+    def reset_parameters(self):
+        """Reset parameters."""
+        self.alpha.data = torch.tensor(1.0)
+
+    def _add_pe(self, x):
+        """Computes positional encoding"""
+        if self.reverse:
+            position = torch.arange(
+                x.size(1) - 1, -1, -1.0, dtype=torch.float32
+            ).unsqueeze(1)
+        else:
+            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+
+        x = x * self.alpha
+        x[:, :, 0::2] += torch.sin(position * self.div_term)
+        x[:, :, 1::2] += torch.cos(position * self.div_term)
+        return x
+
+    def forward(self, x):
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+
+        """
+        if self.use_cache:
+            x = x + self.alpha * self.pe[:, : x.size(1)]
+        else:
+            x = self._add_pe(x)
+        return x
+
+
+class OnnxLegacyRelPositionalEncoding(OnnxPositionalEncoding):
+    """Relative positional encoding module (old version).
+
+    Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+    See : Appendix B in https://arxiv.org/abs/1901.02860
+
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_seq_len (int): Maximum input length.
+
+    """
+
+    def __init__(self, model, max_seq_len=512, use_cache=True):
+        """Initialize class."""
+        super().__init__(model, max_seq_len, reverse=True, use_cache=use_cache)
+
+    def _get_pe(self, x):
+        """Computes positional encoding"""
+        if self.reverse:
+            position = torch.arange(
+                x.size(1) - 1, -1, -1.0, dtype=torch.float32
+            ).unsqueeze(1)
+        else:
+            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+
+        pe = torch.zeros(x.shape)
+        pe[:, :, 0::2] += torch.sin(position * self.div_term)
+        pe[:, :, 1::2] += torch.cos(position * self.div_term)
+        return pe
+
+    def forward(self, x):
+        """Compute positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+            torch.Tensor: Positional embedding tensor (1, time, `*`).
+
+        """
+        x = x * self.xscale
+        if self.use_cache:
+            pos_emb = self.pe[:, : x.size(1)]
+        else:
+            pos_emb = self._get_pe(x)
+        return x, pos_emb
+
+
+class OnnxRelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module (new implementation).
+    Details can be found in https://github.com/espnet/espnet/pull/2816.
+    See : Appendix B in https://arxiv.org/abs/1901.02860
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_seq_len (int): Maximum input length.
+    """
+
+    def __init__(self, model, max_seq_len=512, use_cache=True):
+        """Construct an PositionalEncoding object."""
+        super(OnnxRelPositionalEncoding, self).__init__()
+        self.d_model = model.d_model
+        self.xscale = math.sqrt(self.d_model)
+        self.pe = None
+        self.use_cache = use_cache
+        if self.use_cache:
+            self.extend_pe(torch.tensor(0.0).expand(1, max_seq_len))
+        else:
+            self.div_term = torch.exp(
+                torch.arange(0, self.d_model, 2, dtype=torch.float32)
+                * -(math.log(10000.0) / self.d_model)
+            )
+
+    def extend_pe(self, x):
+        """Reset the positional encodings."""
+        if self.pe is not None and self.pe.size(1) >= x.size(1) * 2 - 1:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+            return
+        # Suppose `i` is the position of the query vector and `j` the
+        # position of the key vector. We use positive relative positions when
+        # keys are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in https://arxiv.org/abs/1901.02860
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def _get_pe(self, x):
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        theta = (
+            torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) * self.div_term
+        )
+        pe_positive[:, 0::2] = torch.sin(theta)
+        pe_positive[:, 1::2] = torch.cos(theta)
+        pe_negative[:, 0::2] = -1 * torch.sin(theta)
+        pe_negative[:, 1::2] = torch.cos(theta)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in https://arxiv.org/abs/1901.02860
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        return torch.cat([pe_positive, pe_negative], dim=1)
+
+    def forward(self, x: torch.Tensor):
+        """Add positional encoding.
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+            torch.Tensor: Relative positional embedding tensor (1, 2*time-1, `*`).
+        """
+        x = x * self.xscale
+        if self.use_cache:
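+            # self.pe caches 2*max_seq_len-1 encodings (positive part first,
+            # then negative); slice the 2*T-1 window centered on the middle.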
+            pos_emb = self.pe[
+                :,
+                self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
+            ]
+        else:
+            pos_emb = self._get_pe(x)
+        return x, pos_emb
+
+
+class OnnxStreamPositionalEncoding(torch.nn.Module):
+    """Streaming Positional encoding."""
+
+    def __init__(self, model, max_seq_len=5000, use_cache=True):
+        """Construct an PositionalEncoding object."""
+        super(StreamPositionalEncoding, self).__init__()
+        self.use_cache = use_cache
+        self.d_model = model.d_model
+        self.xscale = model.xscale
+        self.pe = model.pe
+        self.use_cache = use_cache
+        self.max_seq_len = max_seq_len
+        if self.use_cache:
+            self.extend_pe()
+        else:
+            self.div_term = torch.exp(
+                torch.arange(0, self.d_model, 2, dtype=torch.float32)
+                * -(math.log(10000.0) / self.d_model)
+            )
+        self._register_load_state_dict_pre_hook(_pre_hook)
+
+    def extend_pe(self):
+        """Reset the positional encodings."""
+        pe_length = len(self.pe[0])
+        if self.max_seq_len < pe_length:
+            self.pe = self.pe[:, : self.max_seq_len]
+        else:
+            self.model.extend_pe(self.max_seq_len)
+            self.pe = self.model.pe
+
+    def _add_pe(self, x, start_idx):
+        position = torch.arange(start_idx, x.size(1), dtype=torch.float32).unsqueeze(1)
+        x = x * self.xscale
+        x[:, :, 0::2] += torch.sin(position * self.div_term)
+        x[:, :, 1::2] += torch.cos(position * self.div_term)
+        return x
+
+    def forward(self, x: torch.Tensor, start_idx: int = 0):
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+
+        """
+        if self.use_cache:
+            return x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
+        else:
+            return self._add_pe(x, start_idx)
diff --git a/funasr/export/models/language_models/seq_rnn.py b/funasr/export/models/language_models/seq_rnn.py
new file mode 100644
index 0000000..ecff4b8
--- /dev/null
+++ b/funasr/export/models/language_models/seq_rnn.py
@@ -0,0 +1,84 @@
+import os
+
+import torch
+import torch.nn as nn
+
+class SequentialRNNLM(nn.Module):
+    def __init__(self, model, **kwargs):
+        super().__init__()
+        self.encoder = model.encoder
+        self.rnn = model.rnn
+        self.rnn_type = model.rnn_type
+        self.decoder = model.decoder
+        self.nlayers = model.nlayers
+        self.nhid = model.nhid
+        self.model_name = "seq_rnnlm"
+
+    def forward(self, y, hidden1, hidden2=None):
+        # batch_score function.
+        emb = self.encoder(y)
+        if self.rnn_type == "LSTM":
+            output, (hidden1, hidden2) = self.rnn(emb, (hidden1, hidden2))
+        else:
+            output, hidden1 = self.rnn(emb, hidden1)
+
+        decoded = self.decoder(
+            output.contiguous().view(output.size(0) * output.size(1), output.size(2))
+        )
+        if self.rnn_type == "LSTM":
+            return (
+                decoded.view(output.size(0), output.size(1), decoded.size(1)),
+                hidden1,
+                hidden2,
+            )
+        else:
+            return (
+                decoded.view(output.size(0), output.size(1), decoded.size(1)),
+                hidden1,
+            )
+
+    def get_dummy_inputs(self):
+        tgt = torch.LongTensor([0, 1]).unsqueeze(0)
+        hidden = torch.randn(self.nlayers, 1, self.nhid)
+        if self.rnn_type == "LSTM":
+            return (tgt, hidden, hidden)
+        else:
+            return (tgt, hidden)
+
+    def get_input_names(self):
+        if self.rnn_type == "LSTM":
+            return ["x", "in_hidden1", "in_hidden2"]
+        else:
+            return ["x", "in_hidden1"]
+
+    def get_output_names(self):
+        if self.rnn_type == "LSTM":
+            return ["y", "out_hidden1", "out_hidden2"]
+        else:
+            return ["y", "out_hidden1"]
+
+    def get_dynamic_axes(self):
+        ret = {
+            "x": {0: "x_batch", 1: "x_length"},
+            "y": {0: "y_batch"},
+            "in_hidden1": {1: "hidden1_batch"},
+            "out_hidden1": {1: "out_hidden1_batch"},
+        }
+        if self.rnn_type == "LSTM":
+            ret.update(
+                {
+                    "in_hidden2": {1: "hidden2_batch"},
+                    "out_hidden2": {1: "out_hidden2_batch"},
+                }
+            )
+        return ret
+
+    def get_model_config(self, path):
+        return {
+            "use_lm": True,
+            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
+            "lm_type": "SequentialRNNLM",
+            "rnn_type": self.rnn_type,
+            "nhid": self.nhid,
+            "nlayers": self.nlayers,
+        }
diff --git a/funasr/export/models/language_models/subsampling.py b/funasr/export/models/language_models/subsampling.py
new file mode 100644
index 0000000..e71e127
--- /dev/null
+++ b/funasr/export/models/language_models/subsampling.py
@@ -0,0 +1,185 @@
+"""Subsampling layer definition."""
+
+import torch
+
+
+class OnnxConv2dSubsampling(torch.nn.Module):
+    """Convolutional 2D subsampling (to 1/4 length).
+
+    Args:
+        idim (int): Input dimension.
+        odim (int): Output dimension.
+        dropout_rate (float): Dropout rate.
+        pos_enc (torch.nn.Module): Custom position encoding layer.
+
+    """
+
+    def __init__(self, model):
+        """Construct an Conv2dSubsampling object."""
+        super().__init__()
+        self.conv = model.conv
+        self.out = model.out
+
+    def forward(self, x, x_mask):
+        """Subsample x.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, idim).
+            x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+        Returns:
+            torch.Tensor: Subsampled tensor (#batch, time', odim),
+                where time' = time // 4.
+            torch.Tensor: Subsampled mask (#batch, 1, time'),
+                where time' = time // 4.
+
+        """
+        x = x.unsqueeze(1)  # (b, c, t, f)
+        x = self.conv(x)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        if x_mask is None:
+            return x, None
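+        # the two stride-2 convolutions subsample time to roughly T//4;
+        # the slicing below mirrors the conv output-length arithmetic.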
+        return x, x_mask[:, :-2:2][:, :-2:2]
+
+    def __getitem__(self, key):
+        """Get item.
+
+        When reset_parameters() is called, if use_scaled_pos_enc is used,
+            return the positioning encoding.
+
+        """
+        if key != -1:
+            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
+        return self.out[key]
+
+
+class OnnxConv2dSubsampling2(torch.nn.Module):
+    """Convolutional 2D subsampling (to 1/2 length).
+
+    Args:
+        idim (int): Input dimension.
+        odim (int): Output dimension.
+        dropout_rate (float): Dropout rate.
+        pos_enc (torch.nn.Module): Custom position encoding layer.
+
+    """
+
+    def __init__(self, model):
+        """Construct an Conv2dSubsampling object."""
+        super().__init__()
+        self.conv = model.conv
+        self.out = model.out
+
+    def forward(self, x, x_mask):
+        """Subsample x.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, idim).
+            x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+        Returns:
+            torch.Tensor: Subsampled tensor (#batch, time', odim),
+                where time' = time // 2.
+            torch.Tensor: Subsampled mask (#batch, 1, time'),
+                where time' = time // 2.
+
+        """
+        x = x.unsqueeze(1)  # (b, c, t, f)
+        x = self.conv(x)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        if x_mask is None:
+            return x, None
+        return x, x_mask[:, :-2:2][:, :-2:1]
+
+    def __getitem__(self, key):
+        """Get item.
+
+        When reset_parameters() is called, if use_scaled_pos_enc is used,
+            return the positional encoding.
+
+        """
+        if key != -1:
+            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
+        return self.out[key]
+
+
+class OnnxConv2dSubsampling6(torch.nn.Module):
+    """Convolutional 2D subsampling (to 1/6 length).
+
+    Args:
+        idim (int): Input dimension.
+        odim (int): Output dimension.
+        dropout_rate (float): Dropout rate.
+        pos_enc (torch.nn.Module): Custom position encoding layer.
+
+    """
+
+    def __init__(self, model):
+        """Construct an Conv2dSubsampling object."""
+        super().__init__()
+        self.conv = model.conv
+        self.out = model.out
+
+    def forward(self, x, x_mask):
+        """Subsample x.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, idim).
+            x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+        Returns:
+            torch.Tensor: Subsampled tensor (#batch, time', odim),
+                where time' = time // 6.
+            torch.Tensor: Subsampled mask (#batch, 1, time'),
+                where time' = time // 6.
+
+        """
+        x = x.unsqueeze(1)  # (b, c, t, f)
+        x = self.conv(x)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        if x_mask is None:
+            return x, None
+        return x, x_mask[:, :-2:2][:, :-4:3]
+
+
+class OnnxConv2dSubsampling8(torch.nn.Module):
+    """Convolutional 2D subsampling (to 1/8 length).
+
+    Args:
+        idim (int): Input dimension.
+        odim (int): Output dimension.
+        dropout_rate (float): Dropout rate.
+        pos_enc (torch.nn.Module): Custom position encoding layer.
+
+    """
+
+    def __init__(self, model):
+        """Construct an Conv2dSubsampling object."""
+        super().__init__()
+        self.conv = model.conv
+        self.out = model.out
+
+    def forward(self, x, x_mask):
+        """Subsample x.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, idim).
+            x_mask (torch.Tensor): Input mask (#batch, 1, time).
+
+        Returns:
+            torch.Tensor: Subsampled tensor (#batch, time', odim),
+                where time' = time // 8.
+            torch.Tensor: Subsampled mask (#batch, 1, time'),
+                where time' = time // 8.
+
+        """
+        x = x.unsqueeze(1)  # (b, c, t, f)
+        x = self.conv(x)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        if x_mask is None:
+            return x, None
+        return x, x_mask[:, :-2:2][:, :-2:2][:, :-2:2]
diff --git a/funasr/export/models/language_models/transformer.py b/funasr/export/models/language_models/transformer.py
new file mode 100644
index 0000000..ebf0574
--- /dev/null
+++ b/funasr/export/models/language_models/transformer.py
@@ -0,0 +1,110 @@
+import os
+
+import torch
+import torch.nn as nn
+from funasr.modules.vgg2l import VGG2L
+from funasr.modules.attention import MultiHeadedAttention
+from funasr.modules.subsampling import (
+    Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)
+
+from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as OnnxEncoderLayer
+from funasr.export.models.language_models.embed import Embedding
+from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention
+
+from funasr.export.utils.torch_function import MakePadMask
+
+class TransformerLM(nn.Module):
+    def __init__(self, model, max_seq_len=512, **kwargs):
+        super().__init__()
+        self.embed = Embedding(model.embed, max_seq_len)
+        self.encoder = model.encoder
+        self.decoder = model.decoder
+        self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        # replace multihead attention module into customized module.
+        for i, d in enumerate(self.encoder.encoders):
+            # d is EncoderLayer
+            if isinstance(d.self_attn, MultiHeadedAttention):
+                d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
+            self.encoder.encoders[i] = OnnxEncoderLayer(d)
+
+        self.model_name = "transformer_lm"
+        self.num_heads = self.encoder.encoders[0].self_attn.h
+        self.hidden_size = self.encoder.encoders[0].self_attn.linear_out.out_features
+
+    def prepare_mask(self, mask):
+        if len(mask.shape) == 2:
+            mask = mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask = mask[:, None, :]
+        mask = 1 - mask
+        return mask * -10000.0
+
+    def forward(self, y, cache):
+        feats_length = torch.ones(y.shape).sum(dim=-1).type(torch.long)
+        mask = self.make_pad_mask(feats_length)  # (B, T)
+        mask = (y != 0) * mask
+
+        xs = self.embed(y)
+        # forward_one_step of Encoder
+        if isinstance(
+            self.encoder.embed,
+            (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L),
+        ):
+            xs, mask = self.encoder.embed(xs, mask)
+        else:
+            xs = self.encoder.embed(xs)
+
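+        # run each encoder layer with its cached state from the previous step;
+        # the per-layer outputs are returned as the new cache.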
+        new_cache = []
+        mask = self.prepare_mask(mask)
+        for c, e in zip(cache, self.encoder.encoders):
+            xs, mask = e(xs, mask, c)
+            new_cache.append(xs)
+
+        if self.encoder.normalize_before:
+            xs = self.encoder.after_norm(xs)
+
+        h = self.decoder(xs[:, -1])
+        return h, new_cache
+
+    def get_dummy_inputs(self):
+        tgt = torch.LongTensor([1]).unsqueeze(0)
+        cache = [
+            torch.zeros((1, 1, self.encoder.encoders[0].size))
+            for _ in range(len(self.encoder.encoders))
+        ]
+        return (tgt, cache)
+
+    def is_optimizable(self):
+        return True
+
+    def get_input_names(self):
+        return ["tgt"] + ["cache_%d" % i for i in range(len(self.encoder.encoders))]
+
+    def get_output_names(self):
+        return ["y"] + ["out_cache_%d" % i for i in range(len(self.encoder.encoders))]
+
+    def get_dynamic_axes(self):
+        ret = {"tgt": {0: "tgt_batch", 1: "tgt_length"}}
+        ret.update(
+            {
+                "cache_%d" % d: {0: "cache_%d_batch" % d, 1: "cache_%d_length" % d}
+                for d in range(len(self.encoder.encoders))
+            }
+        )
+        ret.update(
+            {
+                "out_cache_%d"
+                % d: {0: "out_cache_%d_batch" % d, 1: "out_cache_%d_length" % d}
+                for d in range(len(self.encoder.encoders))
+            }
+        )
+        return ret
+
+    def get_model_config(self, path):
+        return {
+            "use_lm": True,
+            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
+            "lm_type": "TransformerLM",
+            "odim": self.encoder.encoders[0].size,
+            "nlayers": len(self.encoder.encoders),
+        }
diff --git a/funasr/modules/vgg2l.py b/funasr/modules/vgg2l.py
new file mode 100644
index 0000000..fb45d0b
--- /dev/null
+++ b/funasr/modules/vgg2l.py
@@ -0,0 +1,92 @@
+"""VGG2L module definition for custom encoder."""
+
+from typing import Tuple, Union
+
+import torch
+
+
+class VGG2L(torch.nn.Module):
+    """VGG2L module for custom encoder.
+
+    Args:
+        idim: Input dimension.
+        odim: Output dimension.
+        pos_enc: Positional encoding class.
+
+    """
+
+    def __init__(self, idim: int, odim: int, pos_enc: torch.nn.Module = None):
+        """Construct a VGG2L object."""
+        super().__init__()
+
+        self.vgg2l = torch.nn.Sequential(
+            torch.nn.Conv2d(1, 64, 3, stride=1, padding=1),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(64, 64, 3, stride=1, padding=1),
+            torch.nn.ReLU(),
+            torch.nn.MaxPool2d((3, 2)),
+            torch.nn.Conv2d(64, 128, 3, stride=1, padding=1),
+            torch.nn.ReLU(),
+            torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
+            torch.nn.ReLU(),
+            torch.nn.MaxPool2d((2, 2)),
+        )
+
+        if pos_enc is not None:
+            self.output = torch.nn.Sequential(
+                torch.nn.Linear(128 * ((idim // 2) // 2), odim), pos_enc
+            )
+        else:
+            self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim)
+
+    def forward(
+        self, feats: torch.Tensor, feats_mask: torch.Tensor
+    ) -> Union[
+        Tuple[torch.Tensor, torch.Tensor],
+        Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor],
+    ]:
+        """Forward VGG2L bottleneck.
+
+        Args:
+            feats: Feature sequences. (B, F, D_feats)
+            feats_mask: Mask of feature sequences. (B, 1, F)
+
+        Returns:
+            vgg_output: VGG output sequences.
+                   (B, sub(F), D_out) or ((B, sub(F), D_out), (B, sub(F), D_att))
+            vgg_mask: Mask of VGG output sequences. (B, 1, sub(F))
+
+        """
+        feats = feats.unsqueeze(1)
+        vgg_output = self.vgg2l(feats)
+
+        b, c, t, f = vgg_output.size()
+
+        vgg_output = self.output(
+            vgg_output.transpose(1, 2).contiguous().view(b, t, c * f)
+        )
+
+        if feats_mask is not None:
+            vgg_mask = self.create_new_mask(feats_mask)
+        else:
+            vgg_mask = feats_mask
+
+        return vgg_output, vgg_mask
+
+    def create_new_mask(self, feats_mask: torch.Tensor) -> torch.Tensor:
+        """Create a subsampled mask of feature sequences.
+
+        Args:
+            feats_mask: Mask of feature sequences. (B, 1, F)
+
+        Returns:
+            vgg_mask: Mask of VGG2L output sequences. (B, 1, sub(F))
+
+        """
+        vgg1_t_len = feats_mask.size(2) - (feats_mask.size(2) % 3)
+        vgg_mask = feats_mask[:, :, :vgg1_t_len][:, :, ::3]
+
+        vgg2_t_len = vgg_mask.size(2) - (vgg_mask.size(2) % 2)
+        vgg_mask = vgg_mask[:, :, :vgg2_t_len][:, :, ::2]
+
+        return vgg_mask
diff --git a/setup.py b/setup.py
index 9c36dd0..43c734c 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@
         "librosa",
         "jamo",  # For kss
         "PyYAML>=5.1.2",
-        "soundfile>=0.10.2",
+        "soundfile>=0.12.1",
         "h5py>=2.10.0",
         "kaldiio>=2.17.0",
         "torch_complex",

--
Gitblit v1.9.1