From f2d8ded57f6403696001d39dd07a1396e5a03658 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 11 Mar 2024 01:24:43 +0800
Subject: [PATCH] export onnx (#1455)

---
 funasr/models/paraformer/model.py                                          |   77 ++
 funasr/bin/export.py                                                       |   40 +
 funasr/models/llm_asr_nar/model.py                                         |   13 
 funasr/models/sanm/multihead_att.py                                        |  243 +++++++
 examples/industrial_data_pretraining/paraformer/export.py                  |   13 
 funasr/models/qwen_audio/audio.py                                          |  446 +++++++++++++
 funasr/models/qwen_audio/model.py                                          |    2 
 funasr/utils/export_utils.py                                               |   74 ++
 examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear.yaml |   94 ++
 examples/industrial_data_pretraining/paraformer/export.sh                  |   18 
 funasr/datasets/llm_datasets_qwenaudio/__init__.py                         |    0 
 funasr/models/sanm/attention.py                                            |   58 +
 funasr/models/whisper/template.yaml                                        |   31 
 funasr/utils/torch_function.py                                             |   80 ++
 funasr/models/fsmn_vad_streaming/model.py                                  |   40 +
 funasr/auto/auto_model.py                                                  |   54 +
 funasr/datasets/llm_datasets/preprocessor.py                               |    2 
 funasr/datasets/llm_datasets_qwenaudio/datasets.py                         |  150 ++++
 funasr/models/paraformer/cif_predictor.py                                  |  116 +++
 funasr/models/fsmn_vad_streaming/encoder.py                                |   52 +
 examples/industrial_data_pretraining/fsmn_vad_streaming/export.py          |   12 
 funasr/models/paraformer/decoder.py                                        |  171 +++++
 funasr/utils/misc.py                                                       |    7 
 /dev/null                                                                  |   37 -
 examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh          |   19 
 funasr/models/sanm/encoder.py                                              |  132 +++
 26 files changed, 1,914 insertions(+), 67 deletions(-)

diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
new file mode 100644
index 0000000..d259104
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4", export_model=True)
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
new file mode 100644
index 0000000..0bb4617
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
@@ -0,0 +1,19 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+
+# method1, inference from model hub
+export HYDRA_FULL_ERROR=1
+
+
+model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+model_revision="v2.0.4"
+
+python funasr/bin/export.py \
+++model=${model} \
+++model_revision=${model_revision} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu" \
+++debug=false
diff --git a/examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear.yaml b/examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear.yaml
new file mode 100644
index 0000000..4603f80
--- /dev/null
+++ b/examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear.yaml
@@ -0,0 +1,94 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: LLMASR
+model_conf:
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: true
+
+# encoder
+audio_encoder: iic/Whisper-large-v2 #iic/Whisper-large-v3
+audio_encoder_conf:
+    hub: ms
+    freeze: true
+    init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
+
+llm: Vicuna
+llm_conf:
+  hub: hf
+  freeze: true
+  init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
+
+audio_adaptor: Linear
+audio_adaptor_conf:
+  downsample_rate: 5
+  llm_dim: 4096
+  encoder_dim: 512
+
+# frontend related
+frontend: WhisperFrontend
+frontend_conf:
+    fs: 16000
+    whisper_model: large-v2
+    do_pad_trim: true
+    permute: true # true: [bs, frames, dims]; false: [bs, dims, frames]
+
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 6
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 15
+  keep_nbest_models: 10
+  log_interval: 10
+
+optim: adamw
+optim_conf:
+   lr: 0.0001
+   weight_decay: 0.000001
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 1000
+
+dataset: AudioLLMQwenAudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: CustomDistributedBatchSampler
+    batch_type: example # example or length
+    batch_size: 4 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    max_token_length: 3000 # filter samples if source_token_len+target_token_len > max_token_length,
+    shuffle: True
+    num_workers: 4
+    preprocessor_text: TextPreprocessRemovePunctuation
+    audio_adaptor_downsample_rate: ${audio_adaptor_conf.downsample_rate}
+    audio_encoder_downsample_rate: 2
+    prompt: "<|startoftranscription|><|zh|><|transcribe|><|zh|><|notimestamps|><|wo_itn|>"
+
+
+
+tokenizer: HuggingfaceTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+  init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
+
diff --git a/examples/industrial_data_pretraining/paraformer/export.py b/examples/industrial_data_pretraining/paraformer/export.py
new file mode 100644
index 0000000..613c3a9
--- /dev/null
+++ b/examples/industrial_data_pretraining/paraformer/export.py
@@ -0,0 +1,13 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+                  model_revision="v2.0.4", export_model=True)
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
diff --git a/examples/industrial_data_pretraining/paraformer/export.sh b/examples/industrial_data_pretraining/paraformer/export.sh
new file mode 100644
index 0000000..9f45a5a
--- /dev/null
+++ b/examples/industrial_data_pretraining/paraformer/export.sh
@@ -0,0 +1,18 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+export HYDRA_FULL_ERROR=1
+
+
+model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model_revision="v2.0.4"
+
+python funasr/bin/export.py \
+++model=${model} \
+++model_revision=${model_revision} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu" \
+++debug=false
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 9ae9f18..d7b6cb9 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -14,6 +14,7 @@
 import numpy as np
 from tqdm import tqdm
 
+from funasr.utils.misc import deep_update
 from funasr.register import tables
 from funasr.utils.load_utils import load_bytes
 from funasr.download.file import download_from_url
@@ -23,12 +24,13 @@
 from funasr.utils.load_utils import load_audio_text_image_video
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
-from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
+from funasr.utils import export_utils
 try:
+    from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
     from funasr.models.campplus.cluster_backend import ClusterBackend
 except:
     print("If you want to use the speaker diarization, please `pip install hdbscan`")
-import pdb
+
 
 def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
     """
@@ -41,7 +43,7 @@
     """
     data_list = []
     key_list = []
-    filelist = [".scp", ".txt", ".json", ".jsonl"]
+    filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
     
     chars = string.ascii_letters + string.digits
     if isinstance(data_in, str) and data_in.startswith('http'): # url
@@ -98,7 +100,9 @@
     def __init__(self, **kwargs):
         if not kwargs.get("disable_log", True):
             tables.print()
-        
+        if kwargs.get("export_model", False):
+            os.environ['EXPORTING_MODEL'] = 'TRUE'
+            
         model, kwargs = self.build_model(**kwargs)
         
         # if vad_model is not None, build vad model else None
@@ -203,7 +207,7 @@
     
     def __call__(self, *args, **cfg):
         kwargs = self.kwargs
-        kwargs.update(cfg)
+        deep_update(kwargs, cfg)
         res = self.model(*args, kwargs)
         return res
 
@@ -216,7 +220,7 @@
         
     def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
         kwargs = self.kwargs if kwargs is None else kwargs
-        kwargs.update(cfg)
+        deep_update(kwargs, cfg)
         model = self.model if model is None else model
         model.eval()
 
@@ -279,7 +283,7 @@
     def inference_with_vad(self, input, input_len=None, **cfg):
         kwargs = self.kwargs
         # step.1: compute the vad model
-        self.vad_kwargs.update(cfg)
+        deep_update(self.vad_kwargs, cfg)
         beg_vad = time.time()
         res = self.inference(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
         end_vad = time.time()
@@ -287,7 +291,7 @@
 
         # step.2 compute asr model
         model = self.model
-        kwargs.update(cfg)
+        deep_update(kwargs, cfg)
         batch_size = int(kwargs.get("batch_size_s", 300))*1000
         batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000
         kwargs["batch_size"] = batch_size
@@ -399,7 +403,7 @@
                     if return_raw_text:
                         result['raw_text'] = ''
                 else:
-                    self.punc_kwargs.update(cfg)
+                    deep_update(self.punc_kwargs, cfg)
                     punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
                     raw_text = copy.copy(result["text"])
                     if return_raw_text: result['raw_text'] = raw_text
@@ -467,3 +471,35 @@
         #                      f"time_escape_all: {time_escape_total_all_samples:0.3f}")
         return results_ret_list
 
+    def export(self, input=None,
+               type : str = "onnx",
+               quantize: bool = False,
+               fallback_num: int = 5,
+               calib_num: int = 100,
+               opset_version: int = 14,
+               **cfg):
+        os.environ['EXPORTING_MODEL'] = 'TRUE'
+        kwargs = self.kwargs
+        deep_update(kwargs, cfg)
+        del kwargs["model"]
+        model = self.model
+        model.eval()
+
+        batch_size = 1
+
+        key_list, data_list = prepare_data_iterator(input, input_len=None, data_type=kwargs.get("data_type", None), key=None)
+
+        with torch.no_grad():
+            
+            if type == "onnx":
+                export_dir = export_utils.export_onnx(
+                                        model=model,
+                                        data_in=data_list,
+                                        **kwargs)
+            else:
+                export_dir = export_utils.export_torchscripts(
+                                        model=model,
+                                        data_in=data_list,
+                                        **kwargs)
+
+        return export_dir
\ No newline at end of file
diff --git a/funasr/bin/export.py b/funasr/bin/export.py
new file mode 100644
index 0000000..68acc17
--- /dev/null
+++ b/funasr/bin/export.py
@@ -0,0 +1,40 @@
+import os
+import hydra
+import logging
+from omegaconf import DictConfig, OmegaConf, ListConfig
+
+from funasr.auto.auto_model import AutoModel
+
+
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(cfg: DictConfig):
+    def to_plain_list(cfg_item):
+        if isinstance(cfg_item, ListConfig):
+            return OmegaConf.to_container(cfg_item, resolve=True)
+        elif isinstance(cfg_item, DictConfig):
+            return {k: to_plain_list(v) for k, v in cfg_item.items()}
+        else:
+            return cfg_item
+    
+    kwargs = to_plain_list(cfg)
+    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+
+    logging.basicConfig(level=log_level)
+
+    if kwargs.get("debug", False):
+        import pdb; pdb.set_trace()
+
+    
+    model = AutoModel(export_model=True, **kwargs)
+    res = model.export(input=kwargs.get("input", None),
+                       type=kwargs.get("type", "onnx"),
+                       quantize=kwargs.get("quantize", False),
+                       fallback_num=kwargs.get("fallback_num", 5),
+                       calib_num=kwargs.get("calib_num", 100),
+                       opset_version=kwargs.get("opset_version", 14),
+                       )
+    print(res)
+
+
+if __name__ == '__main__':
+    main_hydra()
\ No newline at end of file
diff --git a/funasr/datasets/llm_datasets/preprocessor.py b/funasr/datasets/llm_datasets/preprocessor.py
index 9f20672..4cbef33 100644
--- a/funasr/datasets/llm_datasets/preprocessor.py
+++ b/funasr/datasets/llm_datasets/preprocessor.py
@@ -18,7 +18,7 @@
 
 
 @tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
-class TextPreprocessSegDict(nn.Module):
+class TextPreprocessRemovePunctuation(nn.Module):
 	def __init__(self,
 	             **kwargs):
 		super().__init__()
diff --git a/funasr/datasets/llm_datasets_qwenaudio/__init__.py b/funasr/datasets/llm_datasets_qwenaudio/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/llm_datasets_qwenaudio/__init__.py
diff --git a/funasr/datasets/llm_datasets_qwenaudio/datasets.py b/funasr/datasets/llm_datasets_qwenaudio/datasets.py
new file mode 100644
index 0000000..674217c
--- /dev/null
+++ b/funasr/datasets/llm_datasets_qwenaudio/datasets.py
@@ -0,0 +1,150 @@
+import torch
+import copy
+
+from funasr.register import tables
+from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
+
+
+@tables.register("dataset_classes", "AudioLLMQwenAudioDataset")
+class AudioLLMQwenAudioDataset(torch.utils.data.Dataset):
+    """
+    AudioLLMDataset
+    """
+    
+    def __init__(self,
+                 path,
+                 index_ds: str = None,
+                 frontend=None,
+                 tokenizer=None,
+                 int_pad_value: int = -1,
+                 float_pad_value: float = 0.0,
+                 **kwargs):
+        super().__init__()
+        index_ds_class = tables.index_ds_classes.get(index_ds)
+        self.index_ds = index_ds_class(path, **kwargs)
+        preprocessor_speech = kwargs.get("preprocessor_speech", None)
+        if preprocessor_speech:
+            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
+        self.preprocessor_speech = preprocessor_speech
+        preprocessor_text = kwargs.get("preprocessor_text", None)
+        if preprocessor_text:
+            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
+        self.preprocessor_text = preprocessor_text
+        
+        self.frontend = frontend
+        self.fs = 16000 if frontend is None else frontend.fs
+        self.data_type = "sound"
+        self.tokenizer = tokenizer
+        
+        self.float_pad_value = float_pad_value
+        self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
+        # self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(self.prompt)  # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
+        self.prompt_af = ""
+        self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
+        self.int_pad_value = self.IGNORE_INDEX
+        self.audio_adaptor_downsample_rate = kwargs.get("audio_adaptor_downsample_rate", 5)
+        self.audio_encoder_downsample_rate = kwargs.get("audio_encoder_downsample_rate", 2)
+        self.prompt_template = "{}"
+        self.answer_template = "{}"
+        
+    
+    def get_source_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_source_len(item)
+    
+    def get_target_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_target_len(item)
+    
+    def __len__(self):
+        return len(self.index_ds)
+    
+    def __getitem__(self, index):
+        item = self.index_ds[index]
+        # import pdb;
+        # pdb.set_trace()
+        source = item["source"]
+        data_src = load_audio_text_image_video(source, fs=self.fs)
+        if self.preprocessor_speech:
+            data_src = self.preprocessor_speech(data_src, fs=self.fs)
+        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend,
+                                               is_final=True)  # speech: [b, T, d]
+        speech = speech.squeeze(0)
+
+        audio_pseudo_length = (speech.shape[0] + 1) // self.audio_adaptor_downsample_rate // self.audio_encoder_downsample_rate
+        audio_pseudo = torch.full((audio_pseudo_length,), -1) # placeholder
+        
+        target = item["target"]
+        if self.preprocessor_text:
+            target = self.preprocessor_text(target)
+
+        self.prompt_pre = self.prompt_template.format(self.prompt)
+        prompt_ids_pre = self.tokenizer.encode(self.prompt_pre)  # [bos,prompt]
+        prompt_pre_length = len(prompt_ids_pre)
+        
+        # input
+        input = self.answer_template.format(target.lower())
+        prompt_input = "{}{}".format(self.prompt_pre, input)
+        prompt_input_ids = self.tokenizer.encode(prompt_input) # [bos, prompt, input]
+        # audio_length = len(prompt_input_ids) - prompt_pre_length
+        input_ids = prompt_input_ids + [self.tokenizer.pad_token_id] # [bos, prompt, input, pad]
+        input_ids_length = len(input_ids)
+        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [bos, prompt, input, pad]
+        input_ids = torch.cat((audio_pseudo, input_ids)) # [audio, bos, prompt, input, pad]
+        # input_ids[:audio_pseudo_length] = -1 # [-1, bos, prompt, input, pad]
+        attention_mask = input_ids.ge(-1)  # [true, true, true, true, true], length mask
+        # input_ids[prompt_pre_length:] = -1  # [bos, prompt,-1,-1]
+        # attention_mask = input_ids.ge(-1)  # [true, true, true, true], length mask
+        
+        # label
+        answer = self.answer_template.format(target.lower())
+        prompt_answer = "{}{}".format(self.prompt_pre, answer)
+        prompt_answer_ids = self.tokenizer.encode(prompt_answer)
+        # answer_length = len(prompt_answer_ids) - prompt_pre_length
+        labels_ids = copy.deepcopy(prompt_answer_ids) + [self.tokenizer.eos_token_id]
+        labels_ids = torch.tensor(labels_ids, dtype=torch.int64)  # [bos, prompt, answer, eos]
+        labels_ids = torch.cat((audio_pseudo, labels_ids))  # [audio, bos, prompt, answer, eos]
+        labels_ids[:audio_pseudo_length+prompt_pre_length] = -1 # [-1, -1, -1, answer, eos]
+        # labels_ids[:prompt_pre_length] = -1  # [-1, -1, input, eos]
+        label_mask = labels_ids.ge(0)  # [false, false, false, true, true]
+        labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100, -100, -100, answer, eos]
+        
+        # audio_mask for input_ids
+        audio_mask = [1]*audio_pseudo_length + [0]*input_ids_length
+        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
+        
+        ids = self.tokenizer.encode(target)  # token ids is different from labels_ids
+        text = torch.tensor(ids, dtype=torch.int64)
+        text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
+        
+        return {"speech": speech,
+                "speech_lengths": speech_lengths,
+                "text": text,
+                "text_lengths": text_lengths,
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "labels_ids": labels_ids,
+                "label_mask": label_mask,
+                "audio_mask": audio_mask,
+                }
+    
+    def collator(self, samples: list = None):
+        outputs = {}
+        for sample in samples:
+            for key in sample.keys():
+                if key not in outputs:
+                    outputs[key] = []
+                outputs[key].append(sample[key])
+        
+        for key, data_list in outputs.items():
+            if isinstance(data_list[0], torch.Tensor):
+                if data_list[0].dtype == torch.int64:
+                    
+                    pad_value = self.int_pad_value
+                else:
+                    pad_value = self.float_pad_value
+                
+                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+        return outputs
diff --git a/funasr/datasets/llm_datasets_vicuna/preprocessor.py b/funasr/datasets/llm_datasets_vicuna/preprocessor.py
deleted file mode 100644
index 9f20672..0000000
--- a/funasr/datasets/llm_datasets_vicuna/preprocessor.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import os
-import json
-import torch
-import logging
-import concurrent.futures
-import librosa
-import torch.distributed as dist
-from typing import Collection
-import torch
-import torchaudio
-from torch import nn
-import random
-import re
-import string
-from funasr.tokenizer.cleaner import TextCleaner
-from funasr.register import tables
-
-
-
-@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
-class TextPreprocessSegDict(nn.Module):
-	def __init__(self,
-	             **kwargs):
-		super().__init__()
-		
-	
-	def forward(self, text, **kwargs):
-		# 瀹氫箟鑻辨枃鏍囩偣绗﹀彿
-		en_punct = string.punctuation
-		# 瀹氫箟涓枃鏍囩偣绗﹀彿锛堥儴鍒嗗父鐢ㄧ殑锛�
-		cn_punct = '銆傦紵锛侊紝銆侊紱锛氣�溾�濃�樷�欙紙锛夈�娿�嬨�愩�戔�︹�旓綖路'
-		# 鍚堝苟鑻辨枃鍜屼腑鏂囨爣鐐圭鍙�
-		all_punct = en_punct + cn_punct
-		# 鍒涘缓姝e垯琛ㄨ揪寮忔ā寮忥紝鍖归厤浠讳綍鍦╝ll_punct涓殑瀛楃
-		punct_pattern = re.compile('[{}]'.format(re.escape(all_punct)))
-		# 浣跨敤姝e垯琛ㄨ揪寮忕殑sub鏂规硶鏇挎崲鎺夎繖浜涘瓧绗�
-		return punct_pattern.sub('', text)
diff --git a/funasr/models/fsmn_vad_streaming/encoder.py b/funasr/models/fsmn_vad_streaming/encoder.py
index ae91852..e7c0e8b 100755
--- a/funasr/models/fsmn_vad_streaming/encoder.py
+++ b/funasr/models/fsmn_vad_streaming/encoder.py
@@ -1,5 +1,6 @@
 from typing import Tuple, Dict
 import copy
+import os
 
 import numpy as np
 import torch
@@ -134,6 +135,25 @@
         x3 = self.affine(x2)
         x4 = self.relu(x3)
         return x4
+class BasicBlock_export(nn.Module):
+    def __init__(self,
+                 model,
+                 ):
+        super(BasicBlock_export, self).__init__()
+        self.linear = model.linear
+        self.fsmn_block = model.fsmn_block
+        self.affine = model.affine
+        self.relu = model.relu
+
+    def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
+        x = self.linear(input)  # B T D
+        # cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+        # if cache_layer_name not in in_cache:
+        #     in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+        x, out_cache = self.fsmn_block(x, in_cache)
+        x = self.affine(x)
+        x = self.relu(x)
+        return x, out_cache
 
 
 class FsmnStack(nn.Sequential):
@@ -192,6 +212,12 @@
         self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
         self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
         self.softmax = nn.Softmax(dim=-1)
+        
+        # export onnx or torchscripts
+        if "EXPORTING_MODEL" in os.environ and os.environ['EXPORTING_MODEL'] == 'TRUE':
+            for i, d in enumerate(self.fsmn):
+                if isinstance(d, BasicBlock):
+                    self.fsmn[i] = BasicBlock_export(d)
 
     def fuse_modules(self):
         pass
@@ -218,6 +244,32 @@
 
         return x7
 
+    def export_forward(
+            self,
+            input: torch.Tensor,
+            *args,
+    ):
+        """
+        Args:
+            input (torch.Tensor): Input tensor (B, T, D)
+            in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
+            {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
+        """
+
+        x = self.in_linear1(input)
+        x = self.in_linear2(x)
+        x = self.relu(x)
+        # x4 = self.fsmn(x3, in_cache)  # self.in_cache will update automatically in self.fsmn
+        out_caches = list()
+        for i, d in enumerate(self.fsmn):
+            in_cache = args[i]
+            x, out_cache = d(x, in_cache)
+            out_caches.append(out_cache)
+        x = self.out_linear1(x)
+        x = self.out_linear2(x)
+        x = self.softmax(x)
+
+        return x, out_caches
 
 '''
 one deep fsmn layer
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index e29af57..9649ee8 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -284,6 +284,7 @@
 		encoder_class = tables.encoder_classes.get(encoder)
 		encoder = encoder_class(**encoder_conf)
 		self.encoder = encoder
+		self.encoder_conf = encoder_conf
 	
 	def ResetDetection(self, cache: dict = {}):
 		cache["stats"].continous_silence_frame_count = 0
@@ -642,6 +643,45 @@
 		
 		return results, meta_data
 	
+	def export(self, **kwargs):
+		self.forward = self._export_forward
+		
+		return self
+		
+	def _export_forward(self, feats: torch.Tensor, *args, **kwargs):
+		scores, out_caches = self.encoder.export_forward(feats, *args)
+		return scores, out_caches
+	
+	def export_dummy_inputs(self, data_in=None, frame=30):
+		if data_in is None:
+			speech = torch.randn(1, frame, self.encoder_conf.get("input_dim"))
+		else:
+			speech = None # Undo
+		
+		cache_frames = self.encoder_conf.get("lorder") + self.encoder_conf.get("rorder") - 1
+		in_cache0 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
+		in_cache1 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
+		in_cache2 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
+		in_cache3 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
+		
+		return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
+	
+	def export_input_names(self):
+		return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
+	
+	def export_output_names(self):
+		return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
+	
+	def export_dynamic_axes(self):
+		return {
+			'speech': {
+				1: 'feats_length'
+			},
+		}
+	
+	def export_name(self, ):
+		return "model.onnx"
+	
 	def DetectCommonFrames(self, cache: dict = {}) -> int:
 		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
 			return 0
diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py
index 0f4973d..d83f571 100644
--- a/funasr/models/llm_asr_nar/model.py
+++ b/funasr/models/llm_asr_nar/model.py
@@ -218,6 +218,9 @@
     
         audio_mask = kwargs.get("audio_mask", None)
         audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
+        text_token_int = kwargs.get("text_token_int", None)
+        if audio_token_lengths is None:
+            audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)
 
         batch = {"speech": speech, "speech_lengths": speech_lengths}
         enc, enc_lens = self.audio_encoder.encode(**batch)
@@ -259,7 +262,13 @@
             time1 = time.perf_counter()
             audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                             data_type=kwargs.get("data_type", "sound"),
-                                                            tokenizer=tokenizer)
+                                                            tokenizer=None)
+            if len(kwargs.get("data_type")) > 1:
+                audio_sample_list, text_token_int_list = audio_sample_list
+                text_token_int = text_token_int_list[0].replace(" ", "")
+                text_token_int = tokenizer.encode(text_token_int)
+            else:
+                text_token_int = None
             time2 = time.perf_counter()
             meta_data["load_data"] = f"{time2 - time1:0.3f}"
             speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
@@ -272,7 +281,7 @@
         speech_lengths = speech_lengths.to(device=kwargs["device"])
         
         # Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text_token_int=text_token_int)
 
         # adaptor
         encoder_out = self.adaptor(encoder_out)
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 4d9f5d8..ddcfb5a 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -155,7 +155,7 @@
                  tf2torch_tensor_name_prefix_tf="seq2seq/cif",
                  tail_mask=True,
                  ):
-        super(CifPredictorV2, self).__init__()
+        super().__init__()
 
         self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
         self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
@@ -375,6 +375,120 @@
         predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
         return predictor_alignments.detach(), predictor_alignments_length.detach()
 
@tables.register("predictor_classes", "CifPredictorV2Export")
class CifPredictorV2(torch.nn.Module):
    """Export-friendly wrapper around a trained CifPredictorV2.

    Reuses the trained conv/linear submodules and scalar hyper-parameters of
    the source predictor and replaces the training forward with a mask-driven
    one suitable for ONNX tracing (the CIF loop lives in `cif_export`).

    NOTE(review): this class reuses the name ``CifPredictorV2`` and therefore
    shadows the training-time class defined earlier in this module for anyone
    importing it directly; the registry key ``"CifPredictorV2Export"`` is
    unaffected. Consider renaming to ``CifPredictorV2Export``.
    """

    def __init__(self, model, **kwargs):
        super().__init__()
        # Adopt the trained submodules and thresholds of the source predictor.
        for attr in ("pad", "cif_conv1d", "cif_output", "threshold",
                     "smooth_factor", "noise_threshold", "tail_threshold"):
            setattr(self, attr, getattr(model, attr))

    def forward(self, hidden: torch.Tensor,
                mask: torch.Tensor,
                ):
        """Run the CNN alpha head, append the tail frame, then fire via CIF.

        Returns (acoustic_embeds, token_num, alphas, cif_peak).
        """
        alphas, token_num = self.forward_cnn(hidden, mask)
        # (B, 1, T) -> (B, T) float mask for the tail handling.
        mask_bt = mask.transpose(-1, -2).float().squeeze(-1)
        hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask_bt)
        acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)

        return acoustic_embeds, token_num, alphas, cif_peak

    def forward_cnn(self, hidden: torch.Tensor,
                    mask: torch.Tensor,
                    ):
        """Predict per-frame firing weights (alphas) and their sum."""
        feats = self.pad(hidden.transpose(1, 2))
        feats = torch.relu(self.cif_conv1d(feats)).transpose(1, 2)

        alphas = torch.sigmoid(self.cif_output(feats))
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
        # Zero out padded frames, then drop the trailing singleton dim.
        alphas = (alphas * mask.transpose(-1, -2).float()).squeeze(-1)
        token_num = alphas.sum(-1)

        return alphas, token_num

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        """Append one tail frame per utterance so the last token can fire."""
        batch, _, dim = hidden.size()

        zeros_col = torch.zeros((batch, 1), dtype=torch.float32, device=alphas.device)
        ones_col = torch.ones_like(zeros_col)

        # One-hot at the first padded position of every utterance.
        boundary = torch.cat([ones_col, mask], dim=1) - torch.cat([mask, zeros_col], dim=1)
        alphas = torch.cat([alphas, zeros_col], dim=1)
        alphas = torch.add(alphas, boundary * self.tail_threshold)

        tail_frame = torch.zeros((batch, 1, dim), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, tail_frame], dim=1)
        token_num_floor = torch.floor(alphas.sum(dim=-1))

        return hidden, alphas, token_num_floor
+
@torch.jit.script
def cif_export(hidden, alphas, threshold: float):
    """Continuous integrate-and-fire (CIF) aggregation, TorchScript-compiled.

    Integrates the per-frame weights `alphas` over time; whenever the
    accumulator crosses `threshold`, the weighted sum of hidden frames fires
    as one output token. Returns (frame_fires, fires) where `frame_fires` is
    (B, max_tokens, D) and `fires` holds the per-step accumulator values.
    """
    n_batch, n_time, n_dim = hidden.size()
    fire_threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
    one_vec = torch.ones([n_batch], dtype=alphas.dtype, device=hidden.device)

    # Loop state: integrated weight and the partially accumulated frame.
    acc = torch.zeros([n_batch], dtype=alphas.dtype, device=hidden.device)
    emit = torch.zeros([n_batch, n_dim], dtype=hidden.dtype, device=hidden.device)
    fires_per_step = []
    frames_per_step = []

    for step in range(n_time):
        weight = alphas[:, step]
        # How much weight is still needed to reach a full token.
        room_left = one_vec - acc

        acc = acc + weight
        fires_per_step.append(acc)

        fired = acc >= fire_threshold
        acc = torch.where(fired, acc - one_vec, acc)
        # Fired positions consume only the remaining room; others take it all.
        used = torch.where(fired, room_left, weight)
        leftover = weight - used

        emit = emit + used[:, None] * hidden[:, step, :]
        frames_per_step.append(emit)
        # Restart the frame with the leftover weight where a token fired.
        emit = torch.where(fired[:, None].repeat(1, n_dim),
                           leftover[:, None] * hidden[:, step, :],
                           emit)

    fires = torch.stack(fires_per_step, 1)
    frames = torch.stack(frames_per_step, 1)

    fired_mask = fires >= fire_threshold
    frame_fires = torch.zeros_like(hidden)
    max_label_len = frames[0, fired_mask[0]].size(0)
    for b in range(n_batch):
        fired_frames = frames[b, fired_mask[b]]
        n_fired = fired_frames.size(0)
        frame_fires[b, :n_fired, :] = fired_frames

        if n_fired >= max_label_len:
            max_label_len = n_fired
    return frame_fires[:, :max_label_len, :], fires
+
 
 class mae_loss(torch.nn.Module):
 
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index ad321e4..ce018f4 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -581,6 +581,177 @@
 
         return y, new_cache
 
class DecoderLayerSANMExport(torch.nn.Module):
    """Export-friendly SANM decoder layer.

    Borrows the trained submodules of a DecoderLayerSANM and runs the
    pre-norm feed-forward -> self-attention -> cross-attention pipeline with
    explicit masks and an explicit FSMN cache (no training-only branches).
    """

    def __init__(
        self,
        model
    ):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        # Later layer variants may omit these norms.
        self.norm2 = getattr(model, 'norm2', None)
        self.norm3 = getattr(model, 'norm3', None)
        self.size = model.size

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """One decoder layer step; returns (x, tgt_mask, memory, memory_mask, cache)."""
        residual = tgt
        tgt = self.feed_forward(self.norm1(tgt))

        x = tgt
        if self.self_attn is not None:
            attn_out, cache = self.self_attn(self.norm2(tgt), tgt_mask, cache=cache)
            x = residual + attn_out

        if self.src_attn is not None:
            x = x + self.src_attn(self.norm3(x), memory, memory_mask)

        return x, tgt_mask, memory, memory_mask, cache
+
+
@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
class ParaformerSANMDecoderExport(torch.nn.Module):
    """ONNX-export wrapper for a trained Paraformer SANM decoder.

    Wraps each decoder layer in DecoderLayerSANMExport, swaps the SANM
    attention modules for their export-safe counterparts, and exposes a
    mask-based forward plus the metadata hooks (input/output names, dummy
    inputs, dynamic axes) consumed by the ONNX exporter.
    """

    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True, ):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask

        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        from funasr.models.sanm.multihead_att import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.multihead_att import MultiHeadedAttentionCrossAttExport

        # Replace every layer (and its attention modules) with export versions.
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANMExport(d)

        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANMExport(d)

        for i, d in enumerate(self.model.decoders3):
            self.model.decoders3[i] = DecoderLayerSANMExport(d)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def _num_caches(self):
        """Number of per-layer FSMN caches.

        __init__ tolerates ``decoders2 is None``, so the metadata helpers must
        as well — the previous code called ``len(None)`` and raised TypeError.
        """
        num = len(self.model.decoders)
        if self.model.decoders2 is not None:
            num += len(self.model.decoders2)
        return num

    def prepare_mask(self, mask):
        """Derive the two mask forms used by the export attention modules.

        Returns (mask_3d_btd, mask_4d_bhlt): the input mask with a trailing
        singleton axis (multiplicative use), and an additive mask that is
        -10000.0 at padded positions, built according to the mask's rank.
        """
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        else:
            # Previously fell through to an UnboundLocalError; fail clearly.
            raise ValueError(f"unsupported mask rank: {len(mask.shape)}")
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        """Decode acoustic embeddings against encoder memory.

        Returns (logits, ys_in_lens); logits come from the output projection
        after the final layer norm.
        """
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)

        return x, ys_in_lens

    def get_dummy_inputs(self, enc_size):
        """Dummy (tgt, memory, pre_acoustic_embeds, cache) tuple for tracing."""
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(self._num_caches())
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        """This graph is safe to run through the ONNX optimizer."""
        return True

    def get_input_names(self):
        """ONNX input names: fixed inputs followed by one cache per layer."""
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
               + ['cache_%d' % i for i in range(self._num_caches())]

    def get_output_names(self):
        """ONNX output names: logits followed by the updated caches."""
        return ['y'] \
               + ['out_cache_%d' % i for i in range(self._num_caches())]

    def get_dynamic_axes(self):
        """Dynamic (batch/length) axes for torch.onnx.export."""
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(self._num_caches())
        })
        return ret
+
+
 
 @tables.register("decoder_classes", "ParaformerSANDecoder")
 class ParaformerSANDecoder(BaseTransformerDecoder):
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 51c9bdb..8739ed5 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -549,3 +549,80 @@
                 
         return results, meta_data
 
+    def export(
+        self,
+        max_seq_len=512,
+        **kwargs,
+    ):
+        onnx = kwargs.get("onnx", True)
+        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
+        self.encoder = encoder_class(self.encoder, onnx=onnx)
+        
+        predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
+        self.predictor = predictor_class(self.predictor, onnx=onnx)
+
+
+        decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
+        self.decoder = decoder_class(self.decoder, onnx=onnx)
+        
+        from funasr.utils.torch_function import MakePadMask
+        from funasr.utils.torch_function import sequence_mask
+        
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        self.forward = self._export_forward
+        
+        return self
+
+    def _export_forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ):
+        # a. To device
+        batch = {"speech": speech, "speech_lengths": speech_lengths}
+        # batch = to_device(batch, device=self.device)
+    
+        enc, enc_len = self.encoder(**batch)
+        mask = self.make_pad_mask(enc_len)[:, None, :]
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
+        pre_token_length = pre_token_length.floor().type(torch.int32)
+    
+        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        # sample_ids = decoder_out.argmax(dim=-1)
+    
+        return decoder_out, pre_token_length
+
+    def export_dummy_inputs(self):
+        speech = torch.randn(2, 30, 560)
+        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+        return (speech, speech_lengths)
+
+
+    def export_input_names(self):
+        return ['speech', 'speech_lengths']
+
+    def export_output_names(self):
+        return ['logits', 'token_num']
+
+    def export_dynamic_axes(self):
+        return {
+            'speech': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+            'speech_lengths': {
+                0: 'batch_size',
+            },
+            'logits': {
+                0: 'batch_size',
+                1: 'logits_length'
+            },
+        }
+
    def export_name(self, ):
        # File name under which the exported ONNX graph is saved.
        return "model.onnx"
diff --git a/funasr/models/qwen_audio/audio.py b/funasr/models/qwen_audio/audio.py
new file mode 100644
index 0000000..1a37897
--- /dev/null
+++ b/funasr/models/qwen_audio/audio.py
@@ -0,0 +1,446 @@
+import base64
+import gzip
+from dataclasses import dataclass
+from typing import Dict, Iterable, Optional, List
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from subprocess import CalledProcessError, run, Popen, PIPE
+
+import os
+from functools import lru_cache
+from typing import Optional, Union
+from funasr.register import tables
+
+
+
def exact_div(x, y):
    """Integer-divide x by y, asserting that the division is exact."""
    quotient, remainder = divmod(x, y)
    assert remainder == 0
    return quotient
+
# hard-coded audio hyperparameters (Whisper-style mel front end)
SAMPLE_RATE = 16000  # Hz
N_FFT = 400  # STFT window size (25 ms at 16 kHz)
N_MELS = 80  # mel filterbank bins
HOP_LENGTH = 160  # STFT hop (10 ms at 16 kHz)
CHUNK_LENGTH = 30  # seconds of audio per encoder chunk
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
+
+
+
def get_T_after_cnn(L_in, dilation=1):
    """Output length after the audio encoder's two Conv1d layers.

    Mirrors conv1 (kernel 3, stride 1, pad 1) and conv2 (kernel 3, stride 2,
    pad 1) using the standard convolution length formula.
    """
    # (padding, kernel_size, stride) per conv layer. The original evaluated
    # this literal list from a string with eval(), which was needless and an
    # unsafe idiom; a plain literal is equivalent.
    for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:
        L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
        L_out = 1 + L_out // stride
        L_in = L_out
    return L_out
+
def load_bytesio_audio(content, sr: int = SAMPLE_RATE):
    """Decode in-memory audio bytes to a mono float32 waveform at `sr` Hz.

    Pipes `content` through the ffmpeg CLI (must be on PATH) and rescales the
    s16le output into [-1, 1]. Decode errors yield an empty array.
    """
    cmd = [
        "ffmpeg", "-nostdin", "-threads", "0",
        "-i", "pipe:",
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "pipe:",
    ]
    proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, bufsize=-1)
    raw, _ = proc.communicate(input=content)
    samples = np.frombuffer(raw, np.int16).flatten()
    return samples.astype(np.float32) / 32768.0
+
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary.

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.

    Raises
    ------
    RuntimeError if the ffmpeg CLI (required on PATH) fails to decode.
    """
    # Decoding, down-mixing and resampling are delegated to ffmpeg.
    # fmt: off
    cmd = [
        "ffmpeg", "-nostdin", "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-",
    ]
    # fmt: on
    try:
        completed = run(cmd, capture_output=True, check=True)
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    samples = np.frombuffer(completed.stdout, np.int16).flatten()
    return samples.astype(np.float32) / 32768.0
+
+
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """Pad with zeros or trim `array` along `axis` to exactly `length`.

    Handles both torch tensors and numpy arrays; the default matches the
    30-second chunks the encoder expects.
    """
    size = array.shape[axis]
    if torch.is_tensor(array):
        if size > length:
            keep = torch.arange(length, device=array.device)
            array = array.index_select(dim=axis, index=keep)
        elif size < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - size)
            # F.pad expects pairs in reverse axis order, flattened.
            array = F.pad(array, [p for pair in pad_widths[::-1] for p in pair])
    else:
        if size > length:
            array = array.take(indices=range(length), axis=axis)
        elif size < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - size)
            array = np.pad(array, pad_widths)

    return array
+
def trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """Trim `array` along `axis` to at most `length` samples (never pads)."""
    if array.shape[axis] <= length:
        return array
    if torch.is_tensor(array):
        keep = torch.arange(length, device=array.device)
        return array.index_select(dim=axis, index=keep)
    return array.take(indices=range(length), axis=axis)
+
+
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
    """
    Load the (cached) mel filterbank matrix for projecting STFT magnitudes
    into a Mel spectrogram. Decouples the librosa dependency; the file was
    saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
    filters_path = os.path.join(os.path.dirname(__file__), "mel_filters.npz")  # todo
    with np.load(filters_path) as data:
        return torch.from_numpy(data[f"mel_{n_mels}"]).to(device)
+
+
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = N_MELS,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 is supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (80, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))

    hann = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=hann, return_complex=True)
    # Drop the last frame to match the expected frame count, then take power.
    power = stft[..., :-1].abs() ** 2

    mel_spec = mel_filters(audio.device, n_mels) @ power

    # Log-compress, clamp the dynamic range to 8 dB decades, rescale.
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    return (log_spec + 4.0) / 4.0
+
+
@dataclass
class ModelDimensions:
    """Size hyper-parameters of the audio encoder / text decoder pair
    (field names follow Whisper's ModelDimensions)."""
    # audio encoder
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    # text decoder
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int
+
+
class LayerNorm(nn.LayerNorm):
    """LayerNorm that casts its result back to the input dtype."""

    def forward(self, x: Tensor) -> Tensor:
        out = super().forward(x)
        return out.type(x.dtype)
+
+
+
+
class Linear(nn.Linear):
    """Linear layer that casts weight/bias to the input dtype (mixed-precision friendly)."""

    def forward(self, x: Tensor) -> Tensor:
        weight = self.weight.to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return F.linear(x, weight, bias)
+
+
class Conv1d(nn.Conv1d):
    """Conv1d that casts weight/bias to the input dtype before convolving."""

    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        cast_bias = None if bias is None else bias.to(x.dtype)
        return super()._conv_forward(x, weight.to(x.dtype), cast_bias)
+
+
def sinusoids(length, channels, max_timescale=10000):
    """Sinusoidal positional-embedding table of shape (length, channels):
    sines in the first half of the channels, cosines in the second."""
    assert channels % 2 == 0
    half = channels // 2
    log_timescale_increment = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+
class MultiHeadAttention(nn.Module):
    """Multi-head attention with optional cross-attention key/value caching.

    When a kv_cache dict is supplied and `xa` is given, cross-attention keys
    and values are computed once (keyed on the projection modules themselves)
    and reused on subsequent calls.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """Return (attended output, raw attention logits)."""
        q = self.query(x)

        use_cached = kv_cache is not None and xa is not None and self.key in kv_cache
        if use_cached:
            # Cross-attention: reuse keys/values from the first call.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        else:
            # Self-attention projects x; cross-attention projects xa. Hooks,
            # if installed (kv_cache is not None), may prepend cached tensors.
            source = x if xa is None else xa
            k = self.key(source)
            v = self.value(source)

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        """Scaled dot-product attention over per-head reshaped q/k/v."""
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk += mask

        attn = F.softmax(qk, dim=-1).to(q.dtype)
        return (attn @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+
+
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention,
    then a 4x-wide GELU MLP, each with a residual connection."""

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.cross_attn_ln = LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        hidden = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, hidden), nn.GELU(), Linear(hidden, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        return x + self.mlp(self.mlp_ln(x))
+
@tables.register("encoder_classes", "QwenAudioEncoder")
class QwenAudioEncoder(nn.Module):
    """Whisper-style audio encoder used by Qwen-Audio.

    Two convolutions (the second strided by 2) downsample the mel input, a
    stack of ResidualAttentionBlocks encodes it, an optional 2x average pool
    halves the sequence again, and a linear projection maps to `output_dim`.
    Optional learned BOS/EOS embeddings are returned for wrapping the span.
    """

    def __init__(
            self,
            n_mels: int,
            n_ctx: int,
            n_state: int,
            n_head: int,
            n_layer: int,
            output_dim: int = 512,
            avg_pool: bool = True,
            add_audio_bos_eos_token: bool = True,
            **kwargs
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        # Fixed sinusoidal positions, stored as a non-trainable buffer.
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

        if avg_pool:
            self.avg_pooler = nn.AvgPool1d(2, stride=2)
        else:
            self.avg_pooler = None
        self.proj = nn.Linear(n_state, output_dim)
        if add_audio_bos_eos_token:
            # Row 0 is the BOS embedding, row 1 the EOS embedding.
            self.audio_bos_eos_token = nn.Embedding(2, output_dim)
        else:
            self.audio_bos_eos_token = None
        self.output_dim = output_dim
        self.n_head = n_head

    def forward(self, x: Tensor, padding_mask: Tensor=None, audio_lengths: Tensor=None):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        padding_mask : torch.Tensor, optional
            nonzero at padded (post-conv) positions; converted below into an
            additive -inf attention mask
        audio_lengths : torch.Tensor, optional
            per-utterance lengths; column 0 times 2 gives the mel-frame count
            used to crop the batch — TODO confirm column layout with callers

        Returns (encoded, bos, eos); bos/eos are None when the model was
        built without audio BOS/EOS embeddings.
        """
        x = x.to(dtype=self.conv1.weight.dtype,
                 device=self.conv1.weight.device)
        if audio_lengths is not None:
            # Crop to the longest mel length in the batch before the convs.
            input_mel_len = audio_lengths[:,0] * 2
            max_mel_len_in_batch = input_mel_len.max()
            x = x[:, :, :max_mel_len_in_batch]
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)  # B, L, D
        bsz = x.size(0)
        src_len = x.size(1)


        # Slice the positional table to the actual length; note this is a
        # plain attribute overwritten on every call.
        self.input_positional_embedding = self.positional_embedding[:src_len]
        assert x.shape[1:] == self.input_positional_embedding.shape, f"incorrect audio shape: {x.shape[1:], self.input_positional_embedding.shape}"
        x = (x + self.input_positional_embedding).to(x.dtype)
        if padding_mask is not None:
            padding_mask = padding_mask.to(dtype=self.conv1.weight.dtype,
                     device=self.conv1.weight.device)
            batch_src_len = padding_mask.size(1)
            x = x[:, :batch_src_len, :]
            padding_mask = padding_mask.view(
                bsz, -1, batch_src_len
            )
            # A position is padding only if every row of the view marks it.
            padding_mask_ = padding_mask.all(1)
            # Zero out fully-padded positions in place.
            x[padding_mask_] = 0
            # Broadcast the mask over heads and turn it into an additive
            # -inf mask for the attention logits.
            key_padding_mask = padding_mask_.view(bsz, 1, 1, batch_src_len). \
                expand(-1, self.n_head, -1, -1).reshape(bsz, self.n_head, 1, batch_src_len)
            new_padding_mask = torch.zeros_like(key_padding_mask, dtype=x.dtype)
            padding_mask = new_padding_mask.masked_fill(key_padding_mask, float("-inf"))

        for block in self.blocks:
            x = block(x, mask=padding_mask)


        if self.avg_pooler:
            # AvgPool1d works on (B, D, L), hence the permutes.
            x = x.permute(0, 2, 1)
            x = self.avg_pooler(x)
            x = x.permute(0, 2, 1)


        x = self.ln_post(x)
        x = self.proj(x)

        if self.audio_bos_eos_token is not None:
            bos = self.audio_bos_eos_token.weight[0][None, :]
            eos = self.audio_bos_eos_token.weight[1][None, :]
        else:
            bos, eos = None, None
        return x, bos, eos

    def encode(self, input_audios: Tensor, input_audio_lengths: Tensor, audio_span_tokens: List):
        """Encode a batch and return one trimmed embedding per utterance.

        Builds a padding mask (1 = padded) from input_audio_lengths[:, 0],
        runs forward, then cuts each output to audio_span_tokens[i] tokens,
        wrapping with bos/eos when available (the span length includes those
        two wrapper positions, hence the -2).
        """
        real_input_audio_lens = input_audio_lengths[:, 0].tolist()
        max_len_in_batch = max(real_input_audio_lens)
        padding_mask = torch.ones([input_audios.size(0), max_len_in_batch]).to(dtype=self.conv1.weight.dtype,
                                                                               device=self.conv1.weight.device)
        for index in range(len(input_audios)):
            padding_mask[index, :input_audio_lengths[index][0].item()] = 0
        x, bos, eos = self(input_audios, padding_mask,input_audio_lengths)
        output_audios = []
        for i in range(len(audio_span_tokens)):
            audio_span = audio_span_tokens[i]
            audio = x[i][:audio_span-2]
            if bos is not None:
                audio = torch.concat([bos, audio, eos])
            assert len(audio) == audio_span
            output_audios.append(audio)
        return output_audios
diff --git a/funasr/models/qwen_audio/model.py b/funasr/models/qwen_audio/model.py
index 86b4105..045cfe1 100644
--- a/funasr/models/qwen_audio/model.py
+++ b/funasr/models/qwen_audio/model.py
@@ -60,7 +60,7 @@
         inputs = self.tokenizer(query, return_tensors='pt', audio_info=audio_info)
         inputs = inputs.to(self.model.device)
         pred = self.model.generate(**inputs, audio_info=audio_info)
-        response = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False, audio_info=audio_info)
+        response = self.tokenizer.decode(pred.cpu()[0], skip_special_tokens=False, audio_info=audio_info)
 
         results = []
         result_i = {"key": key[0], "text": response}
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index 10f0a3b..09a1f07 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -303,6 +303,64 @@
         att_outs = self.forward_attention(v_h, scores, None)
         return att_outs + fsmn_memory, cache
 
class MultiHeadedAttentionSANMExport(nn.Module):
    """Export-friendly re-implementation of MultiHeadedAttentionSANM.

    Wraps the trained sub-modules of ``model`` and computes the same
    self-attention + FSMN-memory output, but takes the padding masks
    pre-computed as a ``(mask_3d_btd, mask_4d_bhlt)`` pair so the graph
    traces cleanly for ONNX export.
    """

    def __init__(self, model):
        super().__init__()
        self.d_k = model.d_k
        self.h = model.h
        self.linear_out = model.linear_out
        self.linear_q_k_v = model.linear_q_k_v
        self.fsmn_block = model.fsmn_block
        self.pad_fn = model.pad_fn

        # Holds the most recent softmax map, for parity with the wrapped module.
        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, x, mask):
        # mask is a pair: a multiplicative (B, T, 1) mask for the FSMN branch
        # and an additive bias broadcastable over the attention scores.
        mask_3d_btd, mask_4d_bhlt = mask
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
        scores = torch.matmul(q_h * self.d_k ** (-0.5), k_h.transpose(-2, -1))
        return self.forward_attention(v_h, scores, mask_4d_bhlt) + fsmn_memory

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, H*D) -> (B, H, T, D)
        batch, time = x.size(0), x.size(1)
        return x.view(batch, time, self.h, self.d_k).transpose(1, 2)

    def forward_qkv(self, x):
        # One fused projection, split into equal q/k/v chunks; the un-reshaped
        # v is returned as well for the FSMN branch.
        q, k, v = torch.split(self.linear_q_k_v(x), int(self.h * self.d_k), dim=-1)
        return (
            self.transpose_for_scores(q),
            self.transpose_for_scores(k),
            self.transpose_for_scores(v),
            v,
        )

    def forward_fsmn(self, inputs, mask):
        # Depth-wise convolutional memory with a residual connection; padded
        # positions are zeroed both before and after the convolution.
        masked = inputs * mask
        memory = self.fsmn_block(self.pad_fn(masked.transpose(1, 2)))
        return (memory.transpose(1, 2) + masked) * mask

    def forward_attention(self, value, scores, mask):
        # Additive mask pushes padded positions towards -inf before softmax.
        probs = torch.softmax(scores + mask, dim=-1)
        self.attn = probs
        context = torch.matmul(probs, value)  # (batch, head, time1, d_k)
        context = context.transpose(1, 2).contiguous()
        flat_shape = context.size()[:-2] + (self.all_head_size,)
        return self.linear_out(context.view(flat_shape))  # (batch, time1, d_model)
 
 
 class MultiHeadedAttentionSANMDecoder(nn.Module):
diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
index 069c527..561179b 100644
--- a/funasr/models/sanm/encoder.py
+++ b/funasr/models/sanm/encoder.py
@@ -456,3 +456,135 @@
             return (xs_pad, intermediate_outs), None, None
         return xs_pad, ilens, None
 
class EncoderLayerSANMExport(nn.Module):
    """Export-time SANM encoder layer: pre-norm self-attention + feed-forward.

    Borrows the trained sub-modules from ``model``; the residual around the
    attention branch is only applied when input and output widths match.
    """

    def __init__(
        self,
        model,
    ):
        super().__init__()
        self.self_attn = model.self_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2
        self.in_size = model.in_size
        self.size = model.size

    def forward(self, x, mask):
        attn_out = self.self_attn(self.norm1(x), mask)
        # No residual when the layer changes dimensionality (in_size != size).
        x = attn_out + x if self.in_size == self.size else attn_out
        x = self.feed_forward(self.norm2(x)) + x
        return x, mask
+
+
@tables.register("encoder_classes", "SANMEncoderExport")
class SANMEncoderExport(nn.Module):
    """ONNX-export wrapper around a trained SANM encoder.

    Replaces each layer's attention with its export-friendly counterpart and
    precomputes padding masks in shapes the ONNX tracer can handle.

    Args:
        model: trained SANM encoder to wrap (mutated in place).
        max_seq_len: size of the precomputed mask lookup table.
        feats_dim: feature dimension used for the dummy export input.
        model_name: tag used when naming the exported artifact.
        onnx: choose the table-based ONNX mask (True) or the dynamic mask.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='encoder',
        onnx: bool = True,
    ):
        super().__init__()
        self.embed = model.embed
        # The streaming positional encoder is stateful and cannot be traced;
        # in that case features are fed to the layers unchanged.
        if isinstance(self.embed, StreamSinusoidalPositionEncoder):
            self.embed = None
        self.model = model
        self.feats_dim = feats_dim
        self._output_size = model._output_size

        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        from funasr.models.sanm.attention import MultiHeadedAttentionSANMExport
        # Swap every layer (and its attention) for the export variants.
        if hasattr(model, 'encoders0'):
            for i, d in enumerate(self.model.encoders0):
                if isinstance(d.self_attn, MultiHeadedAttentionSANM):
                    d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
                self.model.encoders0[i] = EncoderLayerSANMExport(d)

        for i, d in enumerate(self.model.encoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
                d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
            self.model.encoders[i] = EncoderLayerSANMExport(d)

        self.model_name = model_name
        self.num_heads = model.encoders[0].self_attn.h
        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features

    def prepare_mask(self, mask):
        """Turn a (B, T) / (B, 1, T) pad mask into the (mult, add) pair the
        export attention layers expect."""
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        # Large negative bias zeroes padded positions after the softmax.
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(self,
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                online: bool = False
                ):
        if not online:
            speech = speech * self._output_size ** 0.5
        mask = self.make_pad_mask(speech_lengths)
        mask = self.prepare_mask(mask)
        if self.embed is None:
            xs_pad = speech
        else:
            xs_pad = self.embed(speech)

        # Bug fix: only run the front stack when the wrapped model has one.
        # The original accessed self.model.encoders0 unconditionally, which
        # raised AttributeError for models without encoders0 even though
        # __init__ guards its wrapping with hasattr.
        if hasattr(self.model, 'encoders0'):
            xs_pad, _ = self.model.encoders0(xs_pad, mask)

        xs_pad, _ = self.model.encoders(xs_pad, mask)

        xs_pad = self.model.after_norm(xs_pad)

        return xs_pad, speech_lengths

    def get_output_size(self):
        return self.model.encoders[0].size

    def get_dummy_inputs(self):
        # NOTE(review): funasr/utils/export_utils.py calls export_dummy_inputs
        # / export_input_names / ...; these get_* names appear to serve a
        # different export path — confirm which entry point consumes them.
        feats = torch.randn(1, 100, self.feats_dim)
        return feats

    def get_input_names(self):
        return ['feats']

    def get_output_names(self):
        return ['encoder_out', 'encoder_out_lens', 'predictor_weight']

    def get_dynamic_axes(self):
        return {
            'feats': {
                1: 'feats_length'
            },
            'encoder_out': {
                1: 'enc_out_length'
            },
            'predictor_weight': {
                1: 'pre_out_length'
            }
        }
+
diff --git a/funasr/models/sanm/multihead_att.py b/funasr/models/sanm/multihead_att.py
new file mode 100644
index 0000000..5ef36ed
--- /dev/null
+++ b/funasr/models/sanm/multihead_att.py
@@ -0,0 +1,243 @@
+import os
+import math
+
+import torch
+import torch.nn as nn
+
+
class MultiHeadedAttentionSANMExport(nn.Module):
    """ONNX-export variant of MultiHeadedAttentionSANM (same wrapper also
    exists in funasr/models/sanm/attention.py).

    Takes the padding masks pre-computed as a pair so no data-dependent
    shapes are created inside the traced graph.
    """

    def __init__(self, model):
        super().__init__()
        # Re-use the trained projections / FSMN conv from the wrapped module.
        self.d_k = model.d_k
        self.h = model.h
        self.linear_out = model.linear_out
        self.linear_q_k_v = model.linear_q_k_v
        self.fsmn_block = model.fsmn_block
        self.pad_fn = model.pad_fn

        # Holds the most recent softmax attention map (espnet-style parity).
        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, x, mask):
        # mask: (mask_3d_btd, mask_4d_bhlt) — a multiplicative (B, T, 1) mask
        # for the FSMN branch and an additive bias for the attention scores.
        mask_3d_btd, mask_4d_bhlt = mask
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
        q_h = q_h * self.d_k**(-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
        return att_outs + fsmn_memory

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, H*D) -> (B, H, T, D)
        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward_qkv(self, x):
        # One fused projection split into equal q/k/v chunks; the un-reshaped
        # v is also returned for the FSMN branch.
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = self.transpose_for_scores(q)
        k_h = self.transpose_for_scores(k)
        v_h = self.transpose_for_scores(v)
        return q_h, k_h, v_h, v

    def forward_fsmn(self, inputs, mask):
        # b, t, d = inputs.size()
        # mask = torch.reshape(mask, (b, -1, 1))
        # Depth-wise conv memory with residual; padding zeroed before/after.
        inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x = x + inputs
        x = x * mask
        return x

    def forward_attention(self, value, scores, mask):
        # Additive mask pushes padded positions towards -inf before softmax.
        scores = scores + mask

        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)
+
+
def preprocess_for_attn(x, mask, cache, pad_fn, kernel_size):
    """Mask, transpose and left-pad (or cache-concatenate) FSMN input.

    Returns the (B, D, T') tensor ready for the depth-wise convolution plus
    the updated cache; the cache stays ``None`` when none was supplied.
    """
    x = (x * mask).transpose(1, 2)
    if cache is None:
        return pad_fn(x), None
    x = torch.cat((cache, x), dim=2)
    return x, x[:, :, -(kernel_size - 1):]


# Register the helper as an opaque leaf for torch.fx tracing (the cache
# branch is data-dependent); torch.fx only exists from torch >= 1.8.
torch_version = tuple(int(part) for part in torch.__version__.split(".")[:2])
if torch_version >= (1, 8):
    import torch.fx
    torch.fx.wrap('preprocess_for_attn')


class MultiHeadedAttentionSANMDecoderExport(nn.Module):
    """Export wrapper for the decoder-side FSMN memory block: a depth-wise
    convolution with residual connection and padding mask."""

    def __init__(self, model):
        super().__init__()
        self.fsmn_block = model.fsmn_block
        self.pad_fn = model.pad_fn
        self.kernel_size = model.kernel_size
        # Kept for interface parity with the wrapped module.
        self.attn = None

    def forward(self, inputs, mask, cache=None):
        conv_in, cache = preprocess_for_attn(
            inputs, mask, cache, self.pad_fn, self.kernel_size
        )
        memory = self.fsmn_block(conv_in).transpose(1, 2)
        return (memory + inputs) * mask, cache
+
+
class MultiHeadedAttentionCrossAttExport(nn.Module):
    """Export-friendly cross-attention: the query comes from the decoder
    stream, key/value from one fused projection of the encoder memory."""

    def __init__(self, model):
        super().__init__()
        self.d_k = model.d_k
        self.h = model.h
        self.linear_q = model.linear_q
        self.linear_k_v = model.linear_k_v
        self.linear_out = model.linear_out
        # Holds the most recent softmax map, for parity with the wrapped module.
        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, x, memory, memory_mask):
        q, k, v = self.forward_qkv(x, memory)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, memory_mask)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, H*D) -> (B, H, T, D)
        return x.view(x.size(0), x.size(1), self.h, self.d_k).transpose(1, 2)

    def forward_qkv(self, x, memory):
        # Fused key/value projection split into two equal chunks.
        key, value = torch.split(self.linear_k_v(memory), int(self.h * self.d_k), dim=-1)
        return (
            self.transpose_for_scores(self.linear_q(x)),
            self.transpose_for_scores(key),
            self.transpose_for_scores(value),
        )

    def forward_attention(self, value, scores, mask):
        # Additive mask pushes padded positions towards -inf before softmax.
        probs = torch.softmax(scores + mask, dim=-1)
        self.attn = probs
        context = torch.matmul(probs, value)  # (batch, head, time1, d_k)
        context = context.transpose(1, 2).contiguous()
        flat_shape = context.size()[:-2] + (self.all_head_size,)
        return self.linear_out(context.view(flat_shape))  # (batch, time1, d_model)
+
+
class OnnxMultiHeadedAttention(nn.Module):
    """Standard multi-head attention rewritten for ONNX export.

    Same math as the espnet-style MultiHeadedAttention it wraps; the mask is
    applied additively to the raw scores before the softmax.
    """

    def __init__(self, model):
        super().__init__()
        self.d_k = model.d_k
        self.h = model.h
        self.linear_q = model.linear_q
        self.linear_k = model.linear_k
        self.linear_v = model.linear_v
        self.linear_out = model.linear_out
        # Holds the most recent softmax map, for parity with the wrapped module.
        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, query, key, value, mask):
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, H*D) -> (B, H, T, D)
        return x.view(x.size(0), x.size(1), self.h, self.d_k).transpose(1, 2)

    def forward_qkv(self, query, key, value):
        return (
            self.transpose_for_scores(self.linear_q(query)),
            self.transpose_for_scores(self.linear_k(key)),
            self.transpose_for_scores(self.linear_v(value)),
        )

    def forward_attention(self, value, scores, mask):
        # Additive mask pushes padded positions towards -inf before softmax.
        probs = torch.softmax(scores + mask, dim=-1)
        self.attn = probs
        context = torch.matmul(probs, value)  # (batch, head, time1, d_k)
        context = context.transpose(1, 2).contiguous()
        flat_shape = context.size()[:-2] + (self.all_head_size,)
        return self.linear_out(context.view(flat_shape))  # (batch, time1, d_model)
+
+
class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention):
    """Export variant of relative-position multi-head attention
    (Transformer-XL style, https://arxiv.org/abs/1901.02860)."""

    def __init__(self, model):
        super().__init__(model)
        # Positional projection and the learned u/v biases from the model.
        self.linear_pos = model.linear_pos
        self.pos_bias_u = model.pos_bias_u
        self.pos_bias_v = model.pos_bias_v

    def forward(self, query, key, value, pos_emb, mask):
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        p = self.transpose_for_scores(self.linear_pos(pos_emb)) # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask)

    def rel_shift(self, x):
        # Relative-shift trick: pad one zero column, fold the matrix so each
        # row is shifted by one, drop the first column, then keep only the
        # non-negative relative positions.
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2
        return x

    def forward_attention(self, value, scores, mask):
        # NOTE(review): byte-for-byte identical to the parent's
        # forward_attention — the override looks redundant; confirm and drop.
        scores = scores + mask

        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)
+        
diff --git a/funasr/models/whisper/template.yaml b/funasr/models/whisper/template.yaml
index 40b902c..568b948 100644
--- a/funasr/models/whisper/template.yaml
+++ b/funasr/models/whisper/template.yaml
@@ -11,29 +11,28 @@
     lsm_weight: 0.1
     length_normalized_loss: true
     hub: funasr # openai
-    init_param_path: null # large-v2 or large-v3 if hub == "openai"
 
 
 
 # only use for hub == funasr,
-#  if hub == openai, whisper_dims is automaticall download
-whisper_dims:
-    'n_mels': 80
-    'n_vocab': 51865
-    'n_audio_ctx': 1500
-    'n_audio_state': 1280
-    'n_audio_head': 20
-    'n_audio_layer': 32
-    'n_text_ctx': 448
-    'n_text_state': 1280
-    'n_text_head': 20
-    'n_text_layer': 32
+#  if hub == openai, dims is automatically downloaded
+dims:
+    n_mels: 128
+    n_vocab: 51866
+    n_audio_ctx: 1500
+    n_audio_state: 1280
+    n_audio_head: 20
+    n_audio_layer: 32
+    n_text_ctx: 448
+    n_text_state: 1280
+    n_text_head: 20
+    n_text_layer: 32
 
 # frontend related
 frontend: WhisperFrontend
 frontend_conf:
     fs: 16000
-    n_mels: 80
+    n_mels: ${dims.n_mels}
     do_pad_trim: true
 
 tokenizer: WhisperTokenizer
@@ -41,6 +40,6 @@
   language: null
   task: transcribe
   is_multilingual: true
-  num_languages: 99
+  num_languages: 100
 
-scope_map: ['none', "model."]
+scope_map: [none, "model."]
\ No newline at end of file
diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
new file mode 100644
index 0000000..640be05
--- /dev/null
+++ b/funasr/utils/export_utils.py
@@ -0,0 +1,74 @@
+import os
+import torch
+
def export_onnx(model,
                data_in=None,
                type: str = "onnx",
                quantize: bool = False,
                fallback_num: int = 5,
                calib_num: int = 100,
                opset_version: int = 14,
                **kwargs):
    """Export a FunASR model (possibly several sub-modules) to ONNX.

    Args:
        model: high-level model exposing ``export(**kwargs)`` that returns one
            export-ready module or a list/tuple of them.
        data_in: forwarded to the backend; unused by the ONNX path.
        type: export backend tag (keyword name kept for interface
            compatibility even though it shadows the builtin).
        quantize: also produce a dynamically-quantized uint8 model.
        fallback_num, calib_num: forwarded to the backend via kwargs.
        opset_version: ONNX opset used by torch.onnx.export.
        **kwargs: must contain ``output_dir`` or ``init_param`` (the output
            directory falls back to the checkpoint's directory).

    Returns:
        The directory the ONNX file(s) were written to.
    """
    model_scripts = model.export(**kwargs)
    # Bug fix: the original used
    #   kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))
    # which evaluates the fallback eagerly and raised TypeError whenever
    # "output_dir" was supplied without "init_param". Resolve lazily instead.
    export_dir = kwargs.get("output_dir")
    if export_dir is None:
        export_dir = os.path.dirname(kwargs.get("init_param"))
    os.makedirs(export_dir, exist_ok=True)

    if not isinstance(model_scripts, (list, tuple)):
        model_scripts = (model_scripts,)
    for m in model_scripts:
        m.eval()
        _onnx(m,
              data_in=data_in,
              type=type,
              quantize=quantize,
              fallback_num=fallback_num,
              calib_num=calib_num,
              opset_version=opset_version,
              export_dir=export_dir,
              **kwargs
              )
        print("output dir: {}".format(export_dir))

    return export_dir
+	
def _onnx(model,
          data_in=None,
          quantize: bool = False,
          opset_version: int = 14,
          export_dir: str = None,
          **kwargs):
    """Export one prepared module to ONNX, optionally adding a uint8
    dynamically-quantized copy next to it.

    The module must expose ``export_dummy_inputs`` / ``export_input_names`` /
    ``export_output_names`` / ``export_dynamic_axes`` (and, optionally,
    ``export_name`` for the output file name).
    """
    dummy_input = model.export_dummy_inputs()

    verbose = kwargs.get("verbose", False)

    export_name = model.export_name() if hasattr(model, "export_name") else "model.onnx"
    model_path = os.path.join(export_dir, export_name)
    torch.onnx.export(
        model,
        dummy_input,
        model_path,
        verbose=verbose,
        opset_version=opset_version,
        input_names=model.export_input_names(),
        output_names=model.export_output_names(),
        dynamic_axes=model.export_dynamic_axes(),
    )

    if quantize:
        from onnxruntime.quantization import QuantType, quantize_dynamic
        import onnx
        quant_model_path = model_path.replace(".onnx", "_quant.onnx")
        # Skip if a quantized artifact already exists from a previous run.
        if not os.path.exists(quant_model_path):
            onnx_model = onnx.load(model_path)
            node_names = [node.name for node in onnx_model.graph.node]
            # Keep output projections and bias encoders/decoders in float.
            nodes_to_exclude = [
                name for name in node_names
                if 'output' in name or 'bias_encoder' in name or 'bias_decoder' in name
            ]
            quantize_dynamic(
                model_input=model_path,
                model_output=quant_model_path,
                op_types_to_quantize=['MatMul'],
                per_channel=True,
                reduce_range=False,
                weight_type=QuantType.QUInt8,
                nodes_to_exclude=nodes_to_exclude,
            )
\ No newline at end of file
diff --git a/funasr/utils/misc.py b/funasr/utils/misc.py
index e9050f5..a08f263 100644
--- a/funasr/utils/misc.py
+++ b/funasr/utils/misc.py
@@ -46,3 +46,10 @@
                 value = value.split(' ')
             ret_dict.append((key, value))
         return ret_dict
+
def deep_update(original, update):
    """Recursively merge ``update`` into ``original`` in place.

    Nested dicts are merged key-by-key; any other value — including a dict
    replacing a non-dict — overwrites the existing entry.
    """
    for key, value in update.items():
        # Bug fix: the original recursed whenever `key in original`, even if
        # original[key] was not a dict, crashing inside the recursive call
        # (item assignment on a non-dict). Recurse only when BOTH sides are
        # dicts; otherwise overwrite.
        if isinstance(value, dict) and isinstance(original.get(key), dict):
            deep_update(original[key], value)
        else:
            original[key] = value
\ No newline at end of file
diff --git a/funasr/utils/torch_function.py b/funasr/utils/torch_function.py
new file mode 100644
index 0000000..a078a7e
--- /dev/null
+++ b/funasr/utils/torch_function.py
@@ -0,0 +1,80 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+import numpy as np
+
+
class MakePadMask(nn.Module):
    """ONNX-exportable replacement for espnet's ``make_pad_mask``.

    A (max_seq_len x max_seq_len) triangular table is precomputed once, so a
    mask for length ``l`` is a plain lookup of row ``l - 1`` — no
    data-dependent control flow ends up in the traced graph. With
    ``flip=True`` ones mark the padded positions; with ``flip=False`` ones
    mark the valid positions.
    """

    def __init__(self, max_seq_len=512, flip=True):
        super().__init__()
        tri = np.tri(max_seq_len)
        table = (1 - tri) if flip else tri
        self.mask_pad = torch.Tensor(table).type(torch.bool)

    def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
        """Return a float32 mask; see the class docstring for the convention.

        ``xs`` (2-D or 3-D) and ``maxlen`` optionally fix the time dimension;
        otherwise it defaults to ``max(lengths)``.
        """
        if length_dim == 0:
            raise ValueError("length_dim cannot be 0: {}".format(length_dim))

        # With a 3-D reference tensor, broadcast lengths across the kept axis.
        if xs is not None and len(xs.shape) == 3:
            if length_dim == 1:
                lengths = lengths.unsqueeze(1).expand(*xs.transpose(1, 2).shape[:2])
            else:
                lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])

        if maxlen is not None:
            limit = maxlen
        elif xs is not None:
            limit = xs.shape[-1]
        else:
            limit = torch.max(lengths)

        mask = self.mask_pad[lengths - 1][..., :limit].type(torch.float32)
        return mask.transpose(1, 2) if length_dim == 1 else mask
+
class sequence_mask(nn.Module):
    """Length-to-mask module for the non-ONNX path: ones mark valid positions.

    The lowercase class name and the constructor arguments (which are
    accepted but unused) are kept for backward compatibility with the
    ONNX variant's interface.
    """

    def __init__(self, max_seq_len=512, flip=True):
        super().__init__()

    def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
        if max_seq_len is None:
            max_seq_len = lengths.max()
        positions = torch.arange(0, max_seq_len, 1).to(lengths.device)
        mask = (positions < lengths.unsqueeze(-1)).type(dtype)
        return mask.to(device) if device is not None else mask
+
def normalize(input: torch.Tensor, p: float = 2.0, dim: int = 1, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """p-normalize ``input`` along ``dim``.

    Unlike ``torch.nn.functional.normalize`` there is no epsilon guard, so an
    all-zero slice yields non-finite values — kept for export parity.
    """
    denom = input.norm(p, dim, keepdim=True).expand_as(input)
    if out is None:
        return input / denom
    return torch.div(input, denom, out=out)
+
def subsequent_mask(size: torch.Tensor):
    """Lower-triangular (size x size) causal mask of ones."""
    return torch.tril(torch.ones(size, size))
+
+
# Minimal smoke check for MakePadMask; kept as a plain script entry point
# (not a pytest test) so the module can be run directly.
def MakePadMask_test():
    feats_length = torch.tensor([10]).type(torch.long)
    mask_fn = MakePadMask()
    mask = mask_fn(feats_length)
    print(mask)


if __name__ == '__main__':
    MakePadMask_test()
\ No newline at end of file

--
Gitblit v1.9.1