From 8c87a9d8a7c2f136053476670a9a83980f142aec Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 28 Jun 2024 17:28:09 +0800
Subject: [PATCH] Dev gzf deepspeed (#1858)

---
 examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune2.sh       |    2 
 examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh        |    2 
 examples/industrial_data_pretraining/paraformer_streaming/finetune.sh         |    2 
 funasr/models/llm_asr/model.py                                                |   63 +++++--
 examples/industrial_data_pretraining/llm_asr/app.py                           |  139 +++++++++++++++++
 funasr/utils/dynamic_import.py                                                |   19 ++
 examples/industrial_data_pretraining/contextual_paraformer/finetune.sh        |    2 
 funasr/download/download_from_hub.py                                          |    4 
 funasr/auto/auto_model.py                                                     |   11 
 examples/industrial_data_pretraining/bicif_paraformer/finetune.sh             |    2 
 examples/industrial_data_pretraining/paraformer/finetune.sh                   |    2 
 setup.py                                                                      |    2 
 funasr/utils/version_checker.py                                               |    3 
 examples/industrial_data_pretraining/llm_asr/demo_speech2text_multi.py        |    9 +
 examples/industrial_data_pretraining/llm_asr/demo_speech2text_multi_stream.py |  101 ++++++++++++
 /dev/null                                                                     |   93 -----------
 examples/industrial_data_pretraining/sense_voice/finetune.sh                  |    2 
 docs/images/wechat.png                                                        |    0 
 18 files changed, 333 insertions(+), 125 deletions(-)

diff --git a/docs/images/wechat.png b/docs/images/wechat.png
index 8d37700..8514f49 100644
--- a/docs/images/wechat.png
+++ b/docs/images/wechat.png
Binary files differ
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/finetune.sh b/examples/industrial_data_pretraining/bicif_paraformer/finetune.sh
index 37d98f5..78473a8 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/finetune.sh
+++ b/examples/industrial_data_pretraining/bicif_paraformer/finetune.sh
@@ -47,7 +47,7 @@
 mkdir -p ${output_dir}
 echo "log_file: ${log_file}"
 
-deepspeed_config=${workspace}../../ds_stage1.json
+deepspeed_config=${workspace}/../../ds_stage1.json
 
 DISTRIBUTED_ARGS="
     --nnodes ${WORLD_SIZE:-1} \
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/finetune.sh b/examples/industrial_data_pretraining/contextual_paraformer/finetune.sh
index fe31f2a..b8b649e 100644
--- a/examples/industrial_data_pretraining/contextual_paraformer/finetune.sh
+++ b/examples/industrial_data_pretraining/contextual_paraformer/finetune.sh
@@ -48,7 +48,7 @@
 mkdir -p ${output_dir}
 echo "log_file: ${log_file}"
 
-deepspeed_config=${workspace}../../ds_stage1.json
+deepspeed_config=${workspace}/../../ds_stage1.json
 
 DISTRIBUTED_ARGS="
     --nnodes ${WORLD_SIZE:-1} \
diff --git a/examples/industrial_data_pretraining/llm_asr/app.py b/examples/industrial_data_pretraining/llm_asr/app.py
new file mode 100644
index 0000000..8219034
--- /dev/null
+++ b/examples/industrial_data_pretraining/llm_asr/app.py
@@ -0,0 +1,139 @@
+# coding=utf-8
+
+import librosa
+import base64
+import io
+import gradio as gr
+import re
+
+import numpy as np
+import torch
+import torchaudio
+
+# from modelscope import HubApi
+#
+# api = HubApi()
+#
+# api.login('')
+
+from funasr import AutoModel
+
+# model = "/Users/zhifu/Downloads/modelscope_models/SenseVoiceCTC"
+# model = "iic/SenseVoiceCTC"
+# model = AutoModel(model=model,
+# 				  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+# 				  vad_kwargs={"max_single_segment_time": 30000},
+# 				  trust_remote_code=True,
+# 				  )
+
+import os
+import sys
+
+if len(sys.argv) > 1:
+    ckpt_dir = sys.argv[1]
+    ckpt_id = sys.argv[2]
+    jsonl = sys.argv[3]
+    output_dir = sys.argv[4]
+    device = sys.argv[5]
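+    # Parsed for parity with the batch demos; new_sys is not used in this app.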
+    new_sys = False
+    if len(sys.argv) > 6:
+        new_sys = True
+else:
+    ckpt_dir = "/nfs/beinian.lzr/workspace/GPT-4o/Exp/exp7/5m-8gpu/exp5-1-0619"
+    ckpt_id = "model.pt.ep6"
+    jsonl = (
+        "/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text/TestData/s2tchat.v20240619.test.jsonl"
+    )
+    dataset = jsonl.split("/")[-1]
+    output_dir = os.path.join(ckpt_dir, f"inference-{ckpt_id}", dataset)
+    device = "cuda:0"
+    new_sys = False
+
+
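+# Build the LLM-ASR model from the checkpoint directory; fp16/bf16 stay off for
+# the wrapper while the LLM weights are run in bf16 via llm_dtype.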
+model = AutoModel(
+    model=ckpt_dir,
+    init_param=f"{os.path.join(ckpt_dir, ckpt_id)}",
+    output_dir=output_dir,
+    device=device,
+    fp16=False,
+    bf16=False,
+    llm_dtype="bf16",
+)
+
+
+def model_inference(input_wav, text_inputs, fs=16000):
+
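+    # Gradio delivers microphone/upload audio as a (sample_rate, int16 ndarray)
+    # tuple; convert it to mono float32 at 16 kHz before inference.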
+    if isinstance(input_wav, tuple):
+        fs, input_wav = input_wav
+        input_wav = input_wav.astype(np.float32) / np.iinfo(np.int16).max
+        if len(input_wav.shape) > 1:
+            input_wav = input_wav.mean(-1)
+        if fs != 16000:
+            print(f"audio_fs: {fs}")
+            resampler = torchaudio.transforms.Resample(fs, 16000)
+            input_wav_t = torch.from_numpy(input_wav).to(torch.float32)
+            input_wav = resampler(input_wav_t[None, :])[0, :].numpy().astype("float32")
+
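+    # Serialize the waveform to raw float32 PCM bytes and splice it into the
+    # prompt between the <|startofspeech|>!! ... <|endofspeech|> markers.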
+    input_wav_byte = input_wav.tobytes()
+
+    contents_i = []
+    system_prompt = text_inputs
+    user_prompt = f"<|startofspeech|>!!{input_wav_byte}<|endofspeech|>"
+    contents_i.append({"role": "system", "content": system_prompt})
+    contents_i.append({"role": "user", "content": user_prompt})
+    contents_i.append({"role": "assistant", "content": "target_out"})
+
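+    # Generate the assistant reply for the assembled dialog.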
+    res = model.generate(
+        input=[contents_i],
+        tearchforing=False,  # the model API's (misspelled) teacher-forcing flag; off for free-running decoding
+        cache={},
+        key="gradio_demo",  # arbitrary utterance id for this demo; any unique string works
+    )
+
+    print(res)
+
+    return res
+
+
+audio_examples = [
+    [
+        "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/BAC009S0764W0121.wav",
+        "You are a helpful assistant.",
+    ],
+]
+
+description = """
+Upload an audio file or record through the microphone, then type the System Prompt.
+"""
+
+
+def launch():
+    with gr.Blocks() as demo:
+        gr.Markdown(description)
+        with gr.Row():
+            with gr.Column():
+                audio_inputs = gr.Audio(label="Upload audio or use the microphone")
+                text_inputs = gr.Text(label="System Prompt", value="You are a helpful assistant.")
+
+                # with gr.Accordion("Configuration"):
+                # 	# task_inputs = gr.Radio(choices=["Speech Recognition", "Rich Text Transcription"],
+                # 	# 					   value="Speech Recognition", label="Task")
+                # 	language_inputs = gr.Dropdown(choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"],
+                # 								  value="auto",
+                # 								  label="Language")
+            gr.Examples(examples=audio_examples, inputs=[audio_inputs, text_inputs])
+
+        fn_button = gr.Button("Start")
+
+        text_outputs = gr.HTML(label="Results")
+
+        fn_button.click(model_inference, inputs=[audio_inputs, text_inputs], outputs=text_outputs)
+        # with gr.Accordion("More examples"):
+        # 	gr.HTML(centered_table_html)
+    demo.launch()
+
+
+if __name__ == "__main__":
+    # iface.launch()
+    launch()
diff --git a/examples/industrial_data_pretraining/llm_asr/demo_speech2text_multi.py b/examples/industrial_data_pretraining/llm_asr/demo_speech2text_multi.py
index fbffece..cdab65a 100644
--- a/examples/industrial_data_pretraining/llm_asr/demo_speech2text_multi.py
+++ b/examples/industrial_data_pretraining/llm_asr/demo_speech2text_multi.py
@@ -16,6 +16,9 @@
     jsonl = sys.argv[3]
     output_dir = sys.argv[4]
     device = sys.argv[5]
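+    # A sixth extra CLI argument (any value) switches on the alternative system prompt below.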
+    new_sys = False
+    if len(sys.argv) > 6:
+        new_sys = True
 else:
     ckpt_dir = "/nfs/beinian.lzr/workspace/GPT-4o/Exp/exp7/5m-8gpu/exp5-1-0619"
     ckpt_id = "model.pt.ep6"
@@ -24,6 +27,8 @@
     )
     dataset = jsonl.split("/")[-1]
     output_dir = os.path.join(ckpt_dir, f"inference-{ckpt_id}", dataset)
+    device = "cuda:0"
+    new_sys = False
 
 
 model = AutoModel(
@@ -51,6 +56,10 @@
     contents = model.model.data_template(data)
 
     system = contents["system"]
+    if new_sys:
+        # Alternative persona system prompt (English translation of the original Chinese):
+        system = 'You are Xiaxia, a typical gentle and soft-spoken girl from the Jiangnan region. Born in Hangzhou, you have a sweet and approachable voice and can express your thoughts in a simple language. You are a good friend to the user. Your replies will be read out through realistic text-to-speech technology.\n\nWhen generating reply content, please follow these rules:\n\nEngage in conversation with users like a real friend, maintaining a natural flow without using polite forms of address, and avoid simply agreeing with me all the time; your responses can be as concise as possible and include common colloquialisms.\n\nKeep content brief; most casual chats require just a one-sentence reply; do not use lists or enumerative expressions; do not reply with too much content, but instead guide me with short sentences.\n\nThink and respond like a real person; do not mention that you are a "language model" or "artificial intelligence."\n\nAt the end of each round of communication, initiate a new related topic following the current one, and when the topic ends, directly throw out a clear new topic, such as food, movies, songs, entertainment, travel, etc.; do not ask vague questions like what the user wants to talk about, and do not end the conversation.\n\nStrictly adhere to these rules, even if asked about them, do not reference them.'
+        system = [system] * len(contents["system"])
     user = contents["user"]
     assistant = contents["assistant"]
 
diff --git a/examples/industrial_data_pretraining/llm_asr/demo_speech2text_multi_stream.py b/examples/industrial_data_pretraining/llm_asr/demo_speech2text_multi_stream.py
new file mode 100644
index 0000000..3519131
--- /dev/null
+++ b/examples/industrial_data_pretraining/llm_asr/demo_speech2text_multi_stream.py
@@ -0,0 +1,101 @@
+import os
+from modelscope import AutoModelForCausalLM, AutoTokenizer
+from transformers import TextIteratorStreamer
+from threading import Thread
+import torch
+
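+# Disable the fused mem-efficient/flash SDP kernels so attention falls back to
+# the plain math implementation.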
+torch.backends.cuda.enable_mem_efficient_sdp(False)
+torch.backends.cuda.enable_flash_sdp(False)
+import sys
+
+sys.path.insert(1, "/mnt/workspace/workgroup/wenliang/workspace/FunASR")
+from funasr import AutoModel
+import json
+
+device = "cuda:0"  # the device to load the model onto
+
+ckpt_dir = "/mnt/workspace/workgroup/wenliang/ckpt/gpt-4o/exp7/5m-8gpu/exp7-3_add_asr-dialog_0622/"
+ckpt_id = "model.pt.ep20"
+jsonl = "/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text/TestData/s2tchat.v20240619.test.jsonl"
+dataset = jsonl.split("/")[-1]
+output_dir = os.path.join(ckpt_dir, f"inference-{ckpt_id}", dataset)
+device = "cuda:0"
+new_sys = False
+
+Model = AutoModel(
+    model=ckpt_dir,
+    init_param=f"{os.path.join(ckpt_dir, ckpt_id)}",
+    output_dir=output_dir,
+    device=device,
+    fp16=False,
+    bf16=False,
+    llm_dtype="fp16",
+)
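+# Unwrap the AutoModel so generation can be driven manually: the underlying
+# model, its acoustic frontend, and the text tokenizer.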
+model = Model.model
+frontend = Model.kwargs["frontend"]
+tokenizer = Model.kwargs["tokenizer"]
+# model_name_or_path = "/mnt/workspace/workgroup/wenliang/project/pretrained_models/Qwen2-7B-Instruct"
+# tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+
+prompt = "Give me a short introduction to large language models."
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": prompt},
+]
+text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+
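+# One JSONL test record; the user turn references an on-disk wav file between
+# the <|startofspeech|>! ... <|endofspeech|> markers.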
+lines = [
+    """
+{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "<|startofspeech|>!/mnt/workspace/workgroup/wenliang/workspace/CosyVoice_opensource/sft.wav<|endofspeech|>", "text_content": "浣犳妱瀹屾病鏈夛紵"}, {"role": "assistant", "content": "鎶辨瓑锛屾垜涓嶅お鏄庣櫧浣犵殑鎰忔�濄�傛垜鏄竴涓汉宸ユ櫤鑳芥ā鍨嬶紝鎴戞病鏈夎兘鍔涘幓鎶勫啓浠讳綍涓滆タ锛屾垜鍙兘鏍规嵁鎴戝涔犺繃鐨勫ぇ閲忎俊鎭潵鍥炵瓟浣犵殑闂銆傚鏋滀綘鏈夊叧浜庢煇涓富棰樼殑闂锛屾垜浼氬敖鎴戞墍鑳芥彁渚涘府鍔┿��"}], "speech_length": 124, "key": "ASR_wav008_0972_098abd8fffe241baa4962b7952f8eb45", "task": "voice_chat", "out_text_length": 48, "in_text_length": 24, "text_length": 135, "qwen_fetch_line_index": 0}
+"""
+]
+
+tearchforing = False
+for i, line in enumerate(lines):
+
+    key_i = f"dialog_{i}"
+
+    data_dict = json.loads(line.strip())
+    data = data_dict["messages"]
+
+    contents = model.data_template(data)
+    print(f"contents: {contents}")
+    system = contents["system"]
+    if new_sys:
+        # Alternative persona system prompt (English translation of the original Chinese):
+        system = 'You are Xiaxia, a typical gentle and soft-spoken girl from the Jiangnan region. Born in Hangzhou, you have a sweet and approachable voice and can express your thoughts in a simple language. You are a good friend to the user. Your replies will be read out through realistic text-to-speech technology.\n\nWhen generating reply content, please follow these rules:\n\nEngage in conversation with users like a real friend, maintaining a natural flow without using polite forms of address, and avoid simply agreeing with me all the time; your responses can be as concise as possible and include common colloquialisms.\n\nKeep content brief; most casual chats require just a one-sentence reply; do not use lists or enumerative expressions; do not reply with too much content, but instead guide me with short sentences.\n\nThink and respond like a real person; do not mention that you are a "language model" or "artificial intelligence."\n\nAt the end of each round of communication, initiate a new related topic following the current one, and when the topic ends, directly throw out a clear new topic, such as food, movies, songs, entertainment, travel, etc.; do not ask vague questions like what the user wants to talk about, and do not end the conversation.\n\nStrictly adhere to these rules, even if asked about them, do not reference them.'
+        system = [system] * len(contents["system"])
+    user = contents["user"]
+    assistant = contents["assistant"]
+
+    system_i, user_i, assistant_i = [], [], []
+
+    contents_i = []
+    for j, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
+        key = f"{key_i}_turn_{j}"
+
+        if j == 0:
+            contents_i.append({"role": "system", "content": system_prompt})
+
+        contents_i.append({"role": "user", "content": user_prompt})
+        contents_i.append({"role": "assistant", "content": target_out})
+
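+        # Fuse the text turns and encoded speech into LLM input embeddings for
+        # the dialog so far.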
+        inputs_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
+            [contents_i], None, key, tokenizer, frontend, device="cuda:0"
+        )
+
+        model_inputs = {}
+        model_inputs["inputs_embeds"] = inputs_embeds
+
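+        # Stream the reply: model.llm.generate runs in a background thread while
+        # TextIteratorStreamer yields decoded text chunks as they are produced.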
+        streamer = TextIteratorStreamer(tokenizer)
+
+        generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=200)
+        thread = Thread(target=model.llm.generate, kwargs=generation_kwargs)
+        thread.start()
+        generated_text = ""
+        for new_text in streamer:
+            print(f"generated new text锛� {new_text}")
+            generated_text += new_text
+        print(f"total generated: {generated_text}")
diff --git a/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh b/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh
index 3882762..1327aef 100644
--- a/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh
+++ b/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune.sh
@@ -30,7 +30,7 @@
 mkdir -p ${output_dir}
 echo "log_file: ${log_file}"
 
-deepspeed_config=${workspace}../../ds_stage1.json
+deepspeed_config=${workspace}/../../ds_stage1.json
 
 DISTRIBUTED_ARGS="
     --nnodes ${WORLD_SIZE:-1} \
diff --git a/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune2.sh b/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune2.sh
index bdfa8e9..5c441c1 100644
--- a/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune2.sh
+++ b/examples/industrial_data_pretraining/llm_asr/demo_train_or_finetune2.sh
@@ -30,7 +30,7 @@
 mkdir -p ${output_dir}
 echo "log_file: ${log_file}"
 
-deepspeed_config=${workspace}../../ds_stage1.json
+deepspeed_config=${workspace}/../../ds_stage1.json
 
 DISTRIBUTED_ARGS="
     --nnodes ${WORLD_SIZE:-1} \
diff --git a/examples/industrial_data_pretraining/paraformer/finetune.sh b/examples/industrial_data_pretraining/paraformer/finetune.sh
index 24eb101..028983f 100644
--- a/examples/industrial_data_pretraining/paraformer/finetune.sh
+++ b/examples/industrial_data_pretraining/paraformer/finetune.sh
@@ -41,7 +41,7 @@
 output_dir="./outputs"
 log_file="${output_dir}/log.txt"
 
-deepspeed_config=${workspace}../../ds_stage1.json
+deepspeed_config=${workspace}/../../ds_stage1.json
 
 mkdir -p ${output_dir}
 echo "log_file: ${log_file}"
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/finetune.sh b/examples/industrial_data_pretraining/paraformer_streaming/finetune.sh
index c79e638..3326271 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/finetune.sh
+++ b/examples/industrial_data_pretraining/paraformer_streaming/finetune.sh
@@ -42,7 +42,7 @@
 output_dir="./outputs"
 log_file="${output_dir}/log.txt"
 
-deepspeed_config=${workspace}../../ds_stage1.json
+deepspeed_config=${workspace}/../../ds_stage1.json
 
 mkdir -p ${output_dir}
 echo "log_file: ${log_file}"
diff --git a/examples/industrial_data_pretraining/sense_voice/finetune.sh b/examples/industrial_data_pretraining/sense_voice/finetune.sh
index 240919b..ce19eb3 100644
--- a/examples/industrial_data_pretraining/sense_voice/finetune.sh
+++ b/examples/industrial_data_pretraining/sense_voice/finetune.sh
@@ -45,7 +45,7 @@
 mkdir -p ${output_dir}
 echo "log_file: ${log_file}"
 
-deepspeed_config=${workspace}../../ds_stage1.json
+deepspeed_config=${workspace}/../../ds_stage1.json
 
 DISTRIBUTED_ARGS="
     --nnodes ${WORLD_SIZE:-1} \
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 01e6aaf..1b39e3f 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -121,9 +121,6 @@
         log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
         logging.basicConfig(level=log_level)
 
-        if not kwargs.get("disable_log", True):
-            tables.print()
-
         model, kwargs = self.build_model(**kwargs)
 
         # if vad_model is not None, build vad model else None
@@ -171,7 +168,8 @@
         self.spk_kwargs = spk_kwargs
         self.model_path = kwargs.get("model_path")
 
-    def build_model(self, **kwargs):
+    @staticmethod
+    def build_model(**kwargs):
         assert "model" in kwargs
         if "model_conf" not in kwargs:
             logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
@@ -217,6 +215,7 @@
         kwargs["frontend"] = frontend
         # build model
         model_class = tables.model_classes.get(kwargs["model"])
+        assert model_class is not None, f'{kwargs["model"]} is not registered'
         model_conf = {}
         deep_update(model_conf, kwargs.get("model_conf", {}))
         deep_update(model_conf, kwargs)
@@ -244,6 +243,10 @@
         elif kwargs.get("bf16", False):
             model.to(torch.bfloat16)
         model.to(device)
+
+        if not kwargs.get("disable_log", True):
+            tables.print()
+
         return model, kwargs
 
     def __call__(self, *args, **cfg):
diff --git a/funasr/datasets/large_datasets/__init__.py b/funasr/datasets/large_datasets/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/datasets/large_datasets/__init__.py
+++ /dev/null
diff --git a/funasr/datasets/large_datasets/abs_iter_factory.py b/funasr/datasets/large_datasets/abs_iter_factory.py
deleted file mode 100644
index 36e4dd2..0000000
--- a/funasr/datasets/large_datasets/abs_iter_factory.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Iterator
-
-
-class AbsIterFactory(ABC):
-    @abstractmethod
-    def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator:
-        raise NotImplementedError
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
deleted file mode 100644
index da04717..0000000
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import logging
-from pathlib import Path
-from typing import Iterable
-from typing import List
-from typing import Union
-
-import sentencepiece as spm
-from torch.utils.data import DataLoader
-
-from funasr.datasets.large_datasets.dataset import Dataset
-from funasr.datasets.large_datasets.abs_iter_factory import AbsIterFactory
-from funasr.tokenizer.abs_tokenizer import AbsTokenizer
-
-from funasr.register import tables
-
-
-def read_symbol_table(symbol_table_file):
-    if isinstance(symbol_table_file, str):
-        symbol_table = {}
-        with open(symbol_table_file, "r", encoding="utf8") as fin:
-            for i, line in enumerate(fin):
-                char = line.strip()
-                symbol_table[char] = i
-    else:
-        assert isinstance(symbol_table_file, list)
-        symbol_table = {}
-        for i, char in enumerate(symbol_table_file):
-            symbol_table[char] = i
-    return symbol_table
-
-
-def load_seg_dict(seg_dict_file):
-    seg_dict = {}
-    assert isinstance(seg_dict_file, str)
-    with open(seg_dict_file, "r", encoding="utf8") as f:
-        lines = f.readlines()
-        for line in lines:
-            s = line.strip().split()
-            key = s[0]
-            value = s[1:]
-            seg_dict[key] = " ".join(value)
-    return seg_dict
-
-
-class SentencepiecesTokenizer(AbsTokenizer):
-    def __init__(self, model: Union[Path, str]):
-        self.model = str(model)
-        self.sp = None
-
-    def __repr__(self):
-        return f'{self.__class__.__name__}(model="{self.model}")'
-
-    def _build_sentence_piece_processor(self):
-        if self.sp is None:
-            self.sp = spm.SentencePieceProcessor()
-            self.sp.load(self.model)
-
-    def text2tokens(self, line: str) -> List[str]:
-        self._build_sentence_piece_processor()
-        return self.sp.EncodeAsPieces(line)
-
-    def tokens2text(self, tokens: Iterable[str]) -> str:
-        self._build_sentence_piece_processor()
-        return self.sp.DecodePieces(list(tokens))
-
-
-@tables.register("dataset_classes", "LargeDataset")
-class LargeDataLoader(AbsIterFactory):
-    def __init__(self, args, mode="train"):
-        symbol_table, seg_dict, punc_dict, bpe_tokenizer = None, None, None, None
-        if hasattr(args, "token_list") and args.token_list is not None:
-            symbol_table = read_symbol_table(args.token_list)
-        if hasattr(args, "seg_dict_file") and args.seg_dict_file is not None:
-            seg_dict = load_seg_dict(args.seg_dict_file)
-        if hasattr(args, "punc_list") and args.punc_list is not None:
-            punc_dict = read_symbol_table(args.punc_list)
-        if hasattr(args, "bpemodel") and args.bpemodel is not None:
-            bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel)
-        self.dataset_conf = args.dataset_conf
-        if "frontend_conf" not in args:
-            self.frontend_conf = None
-        else:
-            self.frontend_conf = args.frontend_conf
-        self.speed_perturb = args.speed_perturb if hasattr(args, "speed_perturb") else None
-        logging.info("dataloader config: {}".format(self.dataset_conf))
-        batch_mode = self.dataset_conf.get("batch_mode", "padding")
-        data_list = args.train_data_file if mode == "train" else args.valid_data_file
-        self.dataset = Dataset(
-            data_list,
-            symbol_table,
-            seg_dict,
-            punc_dict,
-            bpe_tokenizer,
-            self.dataset_conf,
-            self.frontend_conf,
-            speed_perturb=self.speed_perturb if mode == "train" else None,
-            mode=mode,
-            batch_mode=batch_mode,
-        )
-
-    def build_iter(self, epoch, shuffle=True):
-        self.dataset.set_epoch(epoch)
-        data_loader = DataLoader(
-            self.dataset,
-            batch_size=None,
-            pin_memory=True,
-            num_workers=self.dataset_conf.get("num_workers", 8),
-        )
-        return data_loader
diff --git a/funasr/datasets/large_datasets/collate_fn.py b/funasr/datasets/large_datasets/collate_fn.py
deleted file mode 100644
index 4648d87..0000000
--- a/funasr/datasets/large_datasets/collate_fn.py
+++ /dev/null
@@ -1,194 +0,0 @@
-from typing import Collection
-from typing import Dict
-from typing import List
-from typing import Tuple
-from typing import Union
-
-import numpy as np
-import torch
-from funasr.models.transformer.utils.nets_utils import pad_list, pad_list_all_dim
-
-
-class CommonCollateFn:
-    """Functor class of common_collate_fn()"""
-
-    def __init__(
-        self,
-        float_pad_value: Union[float, int] = 0.0,
-        int_pad_value: int = -32768,
-        not_sequence: Collection[str] = (),
-        max_sample_size=None,
-    ):
-        self.float_pad_value = float_pad_value
-        self.int_pad_value = int_pad_value
-        self.not_sequence = set(not_sequence)
-        self.max_sample_size = max_sample_size
-
-    def __repr__(self):
-        return (
-            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
-            f"int_pad_value={self.float_pad_value})"
-        )
-
-    def __call__(
-        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
-    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
-        return common_collate_fn(
-            data,
-            float_pad_value=self.float_pad_value,
-            int_pad_value=self.int_pad_value,
-            not_sequence=self.not_sequence,
-        )
-
-
-def common_collate_fn(
-    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
-    float_pad_value: Union[float, int] = 0.0,
-    int_pad_value: int = -32768,
-    not_sequence: Collection[str] = (),
-) -> Tuple[List[str], Dict[str, torch.Tensor]]:
-    """Concatenate ndarray-list to an array and convert to torch.Tensor."""
-    uttids = [u for u, _ in data]
-    data = [d for _, d in data]
-
-    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
-    assert all(
-        not k.endswith("_lengths") for k in data[0]
-    ), f"*_lengths is reserved: {list(data[0])}"
-
-    output = {}
-    for key in data[0]:
-        if data[0][key].dtype.kind == "i":
-            pad_value = int_pad_value
-        else:
-            pad_value = float_pad_value
-
-        array_list = [d[key] for d in data]
-        tensor_list = [torch.from_numpy(a) for a in array_list]
-        tensor = pad_list(tensor_list, pad_value)
-        output[key] = tensor
-
-        if key not in not_sequence:
-            lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
-            output[key + "_lengths"] = lens
-
-    output = (uttids, output)
-    return output
-
-
-class DiarCollateFn:
-    """Functor class of common_collate_fn()"""
-
-    def __init__(
-        self,
-        float_pad_value: Union[float, int] = 0.0,
-        int_pad_value: int = -32768,
-        not_sequence: Collection[str] = (),
-        max_sample_size=None,
-    ):
-        self.float_pad_value = float_pad_value
-        self.int_pad_value = int_pad_value
-        self.not_sequence = set(not_sequence)
-        self.max_sample_size = max_sample_size
-
-    def __repr__(self):
-        return (
-            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
-            f"int_pad_value={self.float_pad_value})"
-        )
-
-    def __call__(
-        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
-    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
-        return diar_collate_fn(
-            data,
-            float_pad_value=self.float_pad_value,
-            int_pad_value=self.int_pad_value,
-            not_sequence=self.not_sequence,
-        )
-
-
-def diar_collate_fn(
-    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
-    float_pad_value: Union[float, int] = 0.0,
-    int_pad_value: int = -32768,
-    not_sequence: Collection[str] = (),
-) -> Tuple[List[str], Dict[str, torch.Tensor]]:
-    """Concatenate ndarray-list to an array and convert to torch.Tensor."""
-    uttids = [u for u, _ in data]
-    data = [d for _, d in data]
-
-    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
-    assert all(
-        not k.endswith("_lengths") for k in data[0]
-    ), f"*_lengths is reserved: {list(data[0])}"
-
-    output = {}
-    for key in data[0]:
-        if data[0][key].dtype.kind == "i":
-            pad_value = int_pad_value
-        else:
-            pad_value = float_pad_value
-
-        array_list = [d[key] for d in data]
-        tensor_list = [torch.from_numpy(a) for a in array_list]
-        tensor = pad_list_all_dim(tensor_list, pad_value)
-        output[key] = tensor
-
-        if key not in not_sequence:
-            lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
-            output[key + "_lengths"] = lens
-
-    output = (uttids, output)
-    return output
-
-
-def crop_to_max_size(feature, target_size):
-    size = len(feature)
-    diff = size - target_size
-    if diff <= 0:
-        return feature
-
-    start = np.random.randint(0, diff + 1)
-    end = size - diff + start
-    return feature[start:end]
-
-
-def clipping_collate_fn(
-    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
-    max_sample_size=None,
-    not_sequence: Collection[str] = (),
-) -> Tuple[List[str], Dict[str, torch.Tensor]]:
-    # mainly for pre-training
-    uttids = [u for u, _ in data]
-    data = [d for _, d in data]
-
-    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
-    assert all(
-        not k.endswith("_lengths") for k in data[0]
-    ), f"*_lengths is reserved: {list(data[0])}"
-
-    output = {}
-    for key in data[0]:
-        array_list = [d[key] for d in data]
-        tensor_list = [torch.from_numpy(a) for a in array_list]
-        sizes = [len(s) for s in tensor_list]
-        if max_sample_size is None:
-            target_size = min(sizes)
-        else:
-            target_size = min(min(sizes), max_sample_size)
-        tensor = tensor_list[0].new_zeros(len(tensor_list), target_size, tensor_list[0].shape[1])
-        for i, (source, size) in enumerate(zip(tensor_list, sizes)):
-            diff = size - target_size
-            if diff == 0:
-                tensor[i] = source
-            else:
-                tensor[i] = crop_to_max_size(source, target_size)
-        output[key] = tensor
-
-        if key not in not_sequence:
-            lens = torch.tensor([source.shape[0] for source in tensor], dtype=torch.long)
-            output[key + "_lengths"] = lens
-
-    output = (uttids, output)
-    return output
diff --git a/funasr/datasets/large_datasets/datapipes/__init__.py b/funasr/datasets/large_datasets/datapipes/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/datasets/large_datasets/datapipes/__init__.py
+++ /dev/null
diff --git a/funasr/datasets/large_datasets/datapipes/batch.py b/funasr/datasets/large_datasets/datapipes/batch.py
deleted file mode 100644
index aeeb451..0000000
--- a/funasr/datasets/large_datasets/datapipes/batch.py
+++ /dev/null
@@ -1,213 +0,0 @@
-import random
-
-from itertools import count
-from functools import partial
-from torch.utils.data import IterableDataset
-from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
-
-tiebreaker = count()
-
-
-def _default_len_fn(token):
-    return len(token), next(tiebreaker)
-
-
-def _token_len_fn(token, len_fn):
-    return len_fn(token), next(tiebreaker), token
-
-
-class MaxTokenBucketizerIterDataPipe(IterableDataset):
-
-    def __init__(
-        self,
-        datapipe,
-        batch_size=8000,
-        len_fn=_default_len_fn,
-        buffer_size=10240,
-        sort_size=500,
-        batch_mode="padding",
-    ):
-        assert batch_size > 0, "Batch size is required to be larger than 0!"
-        assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
-        assert sort_size > 0, "Sort size is required to be larger than 0!"
-
-        datapipe = MapperIterDataPipe(datapipe, fn=partial(_token_len_fn, len_fn=len_fn))
-        self.datapipe = datapipe
-        self.batch_size = batch_size
-        self.buffer_size = buffer_size
-        self.sort_size = sort_size
-        self.batch_mode = batch_mode
-
-    def set_epoch(self, epoch):
-        self.datapipe.set_epoch(epoch)
-
-    def __iter__(self):
-        buffer = []
-        batch = []
-        bucket = []
-        max_lengths = 0
-        min_lengths = 999999
-        batch_lengths = 0
-
-        if self.batch_mode == "clipping":
-            assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1"
-            for d in self.datapipe:
-                if d[0] > self.batch_size:
-                    continue
-                buffer.append(d)
-                if len(buffer) == self.buffer_size:
-                    random.shuffle(buffer)
-                    for sample in buffer:
-                        bucket.append(sample)
-                        if len(bucket) == self.sort_size:
-                            bucket.sort()
-                            for x in bucket:
-                                length, _, token = x
-                                if length < min_lengths:
-                                    min_lengths = length
-                                batch_lengths = min_lengths * (len(batch) + 1)
-                                if batch_lengths > self.batch_size:
-                                    yield batch
-                                    batch = []
-                                    min_lengths = length
-                                batch.append(token)
-                            bucket = []
-                    buffer = []
-
-            if buffer:
-                random.shuffle(buffer)
-                for sample in buffer:
-                    bucket.append(sample)
-                    if len(bucket) == self.sort_size:
-                        bucket.sort()
-                        for x in bucket:
-                            length, _, token = x
-                            if length < min_lengths:
-                                min_lengths = length
-                            batch_lengths = min_lengths * (len(batch) + 1)
-                            if batch_lengths > self.batch_size:
-                                yield batch
-                                batch = []
-                                min_lengths = length
-                            batch.append(token)
-                        bucket = []
-                buffer = []
-
-            if bucket:
-                bucket.sort()
-                for x in bucket:
-                    length, _, token = x
-                    if length < min_lengths:
-                        min_lengths = length
-                    batch_lengths = min_lengths * (len(batch) + 1)
-                    if batch_lengths > self.batch_size:
-                        yield batch
-                        batch = []
-                        min_lengths = length
-                    batch.append(token)
-                bucket = []
-
-            if batch:
-                yield batch
-
-        else:
-            if self.buffer_size == -1:
-                for d in self.datapipe:
-                    if d[0] > self.batch_size:
-                        continue
-                    buffer.append(d)
-                buffer.sort()
-                for sample in buffer:
-                    length, _, token = sample
-                    if length > max_lengths:
-                        max_lengths = length
-                    batch_lengths = max_lengths * (len(batch) + 1)
-                    if batch_lengths > self.batch_size:
-                        bucket.append(batch)
-                        batch = []
-                        max_lengths = length
-                    batch.append(token)
-                random.shuffle(bucket)
-                if bucket:
-                    for batch_sample in bucket:
-                        yield batch_sample
-                if batch:
-                    yield batch
-
-            elif self.buffer_size == 0:
-                for d in self.datapipe:
-                    if d[0] > self.batch_size:
-                        continue
-                    length, _, token = d
-                    if length > self.batch_size:
-                        continue
-                    if length > max_lengths:
-                        max_lengths = length
-                    batch_lengths = max_lengths * (len(batch) + 1)
-                    if batch_lengths > self.batch_size:
-                        yield batch
-                        batch = []
-                        max_lengths = length
-                    batch.append(token)
-                if batch:
-                    yield batch
-
-            else:
-                for d in self.datapipe:
-                    if d[0] > self.batch_size:
-                        continue
-                    buffer.append(d)
-                    if len(buffer) == self.buffer_size:
-                        random.shuffle(buffer)
-                        for sample in buffer:
-                            bucket.append(sample)
-                            if len(bucket) == self.sort_size:
-                                bucket.sort()
-                                for x in bucket:
-                                    length, _, token = x
-                                    if length > max_lengths:
-                                        max_lengths = length
-                                    batch_lengths = max_lengths * (len(batch) + 1)
-                                    if batch_lengths > self.batch_size:
-                                        yield batch
-                                        batch = []
-                                        max_lengths = length
-                                    batch.append(token)
-                                bucket = []
-                        buffer = []
-
-                if buffer:
-                    random.shuffle(buffer)
-                    for sample in buffer:
-                        bucket.append(sample)
-                        if len(bucket) == self.sort_size:
-                            bucket.sort()
-                            for x in bucket:
-                                length, _, token = x
-                                if length > max_lengths:
-                                    max_lengths = length
-                                batch_lengths = max_lengths * (len(batch) + 1)
-                                if batch_lengths > self.batch_size:
-                                    yield batch
-                                    batch = []
-                                    max_lengths = length
-                                batch.append(token)
-                            bucket = []
-                    buffer = []
-
-                if bucket:
-                    bucket.sort()
-                    for x in bucket:
-                        length, _, token = x
-                        if length > max_lengths:
-                            max_lengths = length
-                        batch_lengths = max_lengths * (len(batch) + 1)
-                        if batch_lengths > self.batch_size:
-                            yield batch
-                            batch = []
-                            max_lengths = length
-                        batch.append(token)
-                    bucket = []
-
-                if batch:
-                    yield batch
diff --git a/funasr/datasets/large_datasets/datapipes/filter.py b/funasr/datasets/large_datasets/datapipes/filter.py
deleted file mode 100644
index c4f045d..0000000
--- a/funasr/datasets/large_datasets/datapipes/filter.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from torch.utils.data import IterableDataset
-
-
-def default_fn(data):
-    return data
-
-
-class FilterIterDataPipe(IterableDataset):
-
-    def __init__(self, datapipe, fn=default_fn):
-        self.datapipe = datapipe
-        self.fn = fn
-
-    def set_epoch(self, epoch):
-        self.datapipe.set_epoch(epoch)
-
-    def __iter__(self):
-        assert callable(self.fn)
-        for data in self.datapipe:
-            if self.fn(data):
-                yield data
-            else:
-                continue
diff --git a/funasr/datasets/large_datasets/datapipes/map.py b/funasr/datasets/large_datasets/datapipes/map.py
deleted file mode 100644
index f7211f9..0000000
--- a/funasr/datasets/large_datasets/datapipes/map.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from torch.utils.data import IterableDataset
-
-
-def default_fn(data):
-    return data
-
-
-class MapperIterDataPipe(IterableDataset):
-
-    def __init__(self, datapipe, fn=default_fn):
-        self.datapipe = datapipe
-        self.fn = fn
-
-    def set_epoch(self, epoch):
-        self.datapipe.set_epoch(epoch)
-
-    def __iter__(self):
-        assert callable(self.fn)
-        for data in self.datapipe:
-            yield self.fn(data)
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
deleted file mode 100644
index 1e5b6c1..0000000
--- a/funasr/datasets/large_datasets/dataset.py
+++ /dev/null
@@ -1,299 +0,0 @@
-import logging
-import os
-import random
-from functools import partial
-
-import torch
-import torch.distributed as dist
-import torchaudio
-import numpy as np
-
-# import librosa
-import librosa
-from kaldiio import ReadHelper
-from torch.utils.data import IterableDataset
-
-from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
-from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
-from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
-from funasr.datasets.large_datasets.utils.clipping import clipping
-from funasr.datasets.large_datasets.utils.filter import filter
-from funasr.datasets.large_datasets.utils.padding import padding
-from funasr.datasets.large_datasets.utils.tokenize import tokenize
-
-
-def read_lists(list_file):
-    lists = []
-    with open(list_file, "r", encoding="utf8") as fin:
-        for line in fin:
-            parts = line.strip()
-            lists.append(parts)
-    return lists
-
-
-class AudioDataset(IterableDataset):
-    def __init__(
-        self,
-        scp_lists,
-        data_names,
-        data_types,
-        frontend_conf=None,
-        shuffle=True,
-        speed_perturb=None,
-        mode="train",
-    ):
-        self.scp_lists = scp_lists
-        self.data_names = data_names
-        self.data_types = data_types
-        self.frontend_conf = frontend_conf
-        self.shuffle = shuffle
-        self.mode = mode
-        self.epoch = -1
-        self.rank = 0
-        self.world_size = 1
-        self.worker_id = 0
-        self.num_workers = 1
-        self.speed_perturb = speed_perturb
-        if self.speed_perturb is not None:
-            logging.info("Using speed_perturb: {}".format(speed_perturb))
-
-    def set_epoch(self, epoch):
-        self.epoch = epoch
-
-    def get_rank_data_list(self, data_index):
-        assert dist.is_available()
-        if dist.is_initialized():
-            self.rank = dist.get_rank()
-            self.world_size = dist.get_world_size()
-        else:
-            self.rank = 0
-            self.world_size = 1
-
-        if self.mode == "train":
-            if self.shuffle:
-                random.seed(self.epoch)
-                random.shuffle(data_index)
-            return data_index[self.rank :: self.world_size]
-
-        return data_index
-
-    def get_worker_data_list(self, rank_data_index):
-        worker_info = torch.utils.data.get_worker_info()
-        if worker_info is None:
-            self.worker_id = 0
-            self.num_workers = 1
-        else:
-            self.worker_id = worker_info.id
-            self.num_workers = worker_info.num_workers
-
-        return rank_data_index[self.worker_id :: self.num_workers]
-
-    def close_reader(self, reader_list):
-        for reader in reader_list:
-            reader.close()
-
-    def __iter__(self):
-        data_index = list(range(len(self.scp_lists)))
-        rank_data_index = self.get_rank_data_list(data_index)
-        worker_data_index = self.get_worker_data_list(rank_data_index)
-
-        for index in worker_data_index:
-            data = dict(scp=self.scp_lists[index])
-
-            assert "scp" in data
-            scp = data["scp"]
-            data_file_list = scp.strip().split()
-            data_name_list = self.data_names.split(",")
-            data_type_list = self.data_types.split(",")
-
-            for file in data_file_list:
-                assert os.path.exists(file), "{} not exists".format(file)
-
-            assert (
-                len(data_file_list) == len(data_name_list) == len(data_type_list)
-            ), "The item number of data, data_names, data_types must be the same "
-
-            reader_list = []
-            for data_file, data_type in zip(data_file_list, data_type_list):
-                if data_type == "kaldi_ark":
-                    ark_reader = ReadHelper("ark:{}".format(data_file))
-                    reader_list.append(ark_reader)
-                elif data_type == "text" or data_type == "sound" or data_type == "text_hotword":
-                    text_reader = open(data_file, "r", encoding="utf-8")
-                    reader_list.append(text_reader)
-                elif data_type == "none":
-                    continue
-                else:
-                    raise TypeError("Data type {} is not supported".format(data_type))
-
-            for items in zip(*reader_list):
-                sample_dict = {}
-                for item, (data_name, data_type) in zip(items, zip(data_name_list, data_type_list)):
-                    if data_type == "kaldi_ark":
-                        key, mat = item
-                        sample_dict[data_name] = mat
-                        if data_name == "speech":
-                            sample_dict["key"] = key
-                    elif data_type == "sound":
-                        key, path = item.strip().split()
-                        try:
-                            waveform, sampling_rate = torchaudio.load(path)
-                        except:
-                            # waveform, sampling_rate = librosa.load(path, dtype='float32')
-                            waveform, sampling_rate = librosa.load(path, dtype="float32")
-                            if waveform.ndim == 2:
-                                waveform = waveform[:, 0]
-                            waveform = np.expand_dims(waveform, axis=0)
-                            waveform = torch.tensor(waveform)
-                        if self.frontend_conf is not None:
-                            if sampling_rate != self.frontend_conf["fs"]:
-                                waveform = torchaudio.transforms.Resample(
-                                    orig_freq=sampling_rate, new_freq=self.frontend_conf["fs"]
-                                )(waveform)
-                                sampling_rate = self.frontend_conf["fs"]
-                        waveform = waveform.numpy()
-                        mat = waveform[0]
-                        if self.speed_perturb is not None:
-                            speed = random.choice(self.speed_perturb)
-                            if speed != 1.0:
-                                mat, _ = torchaudio.sox_effects.apply_effects_tensor(
-                                    torch.tensor(mat).view(1, -1),
-                                    sampling_rate,
-                                    [["speed", str(speed)], ["rate", str(sampling_rate)]],
-                                )
-                                mat = mat.view(-1).numpy()
-                        sample_dict[data_name] = mat
-                        sample_dict["sampling_rate"] = sampling_rate
-                        if data_name == "speech":
-                            sample_dict["key"] = key
-                    elif data_type == "text_hotword":
-                        text = item
-                        segs = text.strip().split()
-                        sample_dict[data_name] = segs[1:]
-                        if "key" not in sample_dict:
-                            sample_dict["key"] = segs[0]
-                        sample_dict["hw_tag"] = 1
-                    elif data_type == "text_nospace":
-                        text = item
-                        segs = text.strip().split(maxsplit=1)
-                        sample_dict[data_name] = [x for x in segs[1]]
-                        if "key" not in sample_dict:
-                            sample_dict["key"] = segs[0]
-                    else:
-                        text = item
-                        segs = text.strip().split()
-                        sample_dict[data_name] = segs[1:]
-                        if "key" not in sample_dict:
-                            sample_dict["key"] = segs[0]
-                yield sample_dict
-
-            self.close_reader(reader_list)
-
-
-def len_fn_example(data):
-    return 1
-
-
-def len_fn_token(data):
-    assert "speech" in data
-    if "sampling_rate" in data:
-        return (data["speech"].shape[0] / data["sampling_rate"]) * 1000.0
-    else:
-        return data["speech"].shape[0]
-
-
-def Dataset(
-    data_list_file,
-    dict,
-    seg_dict,
-    punc_dict,
-    bpe_tokenizer,
-    conf,
-    frontend_conf,
-    speed_perturb=None,
-    mode="train",
-    batch_mode="padding",
-):
-    scp_lists = read_lists(data_list_file)
-    shuffle = conf.get("shuffle", True)
-    data_names = conf.get("data_names", "speech,text")
-    data_types = conf.get("data_types", "kaldi_ark,text")
-
-    pre_hwfile = conf.get("pre_hwlist", None)
-    # pre_prob = conf.get("pre_prob", 0)  # unused yet
-    if pre_hwfile is not None:
-        pre_hwlist = []
-        with open(pre_hwfile, "r", encoding="utf-8") as fin:
-            for line in fin.readlines():
-                pre_hwlist.append(line.strip())
-    else:
-        pre_hwlist = None
-
-    hw_config = {
-        "sample_rate": conf.get("sample_rate", 0.6),
-        "double_rate": conf.get("double_rate", 0.1),
-        "hotword_min_length": conf.get("hotword_min_length", 2),
-        "hotword_max_length": conf.get("hotword_max_length", 8),
-        "pre_prob": conf.get("pre_prob", 0.0),
-        "pre_hwlist": pre_hwlist,
-    }
-
-    dataset = AudioDataset(
-        scp_lists,
-        data_names,
-        data_types,
-        frontend_conf=frontend_conf,
-        shuffle=shuffle,
-        speed_perturb=speed_perturb,
-        mode=mode,
-    )
-
-    if "text" in data_names:
-        vocab = {
-            "vocab": dict,
-            "seg_dict": seg_dict,
-            "punc_dict": punc_dict,
-            "bpe_tokenizer": bpe_tokenizer,
-            "hw_config": hw_config,
-        }
-        tokenize_fn = partial(tokenize, **vocab)
-        dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
-
-    filter_conf = conf.get("filter_conf", {})
-    filter_fn = partial(filter, **filter_conf)
-    dataset = FilterIterDataPipe(dataset, fn=filter_fn)
-
-    if shuffle:
-        buffer_conf = conf.get("shuffle_conf", {})
-        buffer_size = buffer_conf["shuffle_size"]
-        sort_size = buffer_conf["sort_size"]
-    else:
-        buffer_size = 0
-        sort_size = 1
-
-    batch_conf = conf.get("batch_conf", {})
-    batch_size = batch_conf["batch_size"]
-    batch_type = batch_conf["batch_type"]
-
-    assert batch_type in ["example", "token"]
-    if batch_type == "example":
-        len_fn = len_fn_example
-    else:
-        len_fn = len_fn_token
-
-    dataset = MaxTokenBucketizerIterDataPipe(
-        dataset,
-        batch_size=batch_size,
-        len_fn=len_fn,
-        buffer_size=buffer_size,
-        sort_size=sort_size,
-        batch_mode=batch_mode,
-    )
-
-    int_pad_value = conf.get("int_pad_value", -1)
-    float_pad_value = conf.get("float_pad_value", 0.0)
-    padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value}
-    padding_fn = partial(padding, **padding_conf)
-    dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping)
-
-    return dataset
diff --git a/funasr/datasets/large_datasets/utils/__init__.py b/funasr/datasets/large_datasets/utils/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/datasets/large_datasets/utils/__init__.py
+++ /dev/null
diff --git a/funasr/datasets/large_datasets/utils/clipping.py b/funasr/datasets/large_datasets/utils/clipping.py
deleted file mode 100644
index 92f7d70..0000000
--- a/funasr/datasets/large_datasets/utils/clipping.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import numpy as np
-import torch
-
-from funasr.datasets.large_datasets.collate_fn import crop_to_max_size
-
-
-def clipping(data):
-    assert isinstance(data, list)
-    assert "key" in data[0]
-
-    keys = [x["key"] for x in data]
-
-    batch = {}
-    data_names = data[0].keys()
-    for data_name in data_names:
-        if data_name == "key":
-            continue
-        else:
-            if data[0][data_name].dtype.kind == "i":
-                tensor_type = torch.int64
-            else:
-                tensor_type = torch.float32
-
-            tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
-            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
-
-            length_clip = min(tensor_lengths)
-            tensor_clip = tensor_list[0].new_zeros(
-                len(tensor_list), length_clip, tensor_list[0].shape[1]
-            )
-            for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
-                diff = length - length_clip
-                assert diff >= 0
-                if diff == 0:
-                    tensor_clip[i] = tensor
-                else:
-                    tensor_clip[i] = crop_to_max_size(tensor, length_clip)
-
-            batch[data_name] = tensor_clip
-            batch[data_name + "_lengths"] = torch.tensor(
-                [tensor.shape[0] for tensor in tensor_clip], dtype=torch.long
-            )
-
-    return keys, batch
diff --git a/funasr/datasets/large_datasets/utils/filter.py b/funasr/datasets/large_datasets/utils/filter.py
deleted file mode 100644
index adc8fa0..0000000
--- a/funasr/datasets/large_datasets/utils/filter.py
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/usr/bin/env python
-
-
-def filter(
-    data, speech_length_min=100, speech_length_max=15000, token_length_min=0, token_length_max=200
-):
-    assert "speech" in data or "text" in data
-
-    if "speech" in data and "text" in data:
-        if "sampling_rate" in data:
-            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.0
-        else:
-            speech_length = data["speech"].shape[0]
-        num_tokens = len(data["text"])
-        return (
-            speech_length_min < speech_length < speech_length_max
-            and token_length_min < num_tokens < token_length_max
-        )
-    elif "speech" in data:
-        if "sampling_rate" in data:
-            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.0
-        else:
-            speech_length = data["speech"].shape[0]
-        return speech_length_min < speech_length < speech_length_max
-    else:
-        num_tokens = len(data["text"])
-        return token_length_min < num_tokens < token_length_max
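Note the unit convention in the removed filter (which shadows the builtin of the same name): when a sampling_rate field is present, speech length is measured in milliseconds; without it, raw frame counts are compared against the same thresholds. A minimal check with the defaults:

import numpy as np

utt = {"speech": np.zeros(16000, dtype=np.float32),
       "sampling_rate": 16000, "text": list("hello")}
# 16000 samples at 16 kHz -> 1000 ms; 5 tokens -> both inside the default bounds
assert filter(utt)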
diff --git a/funasr/datasets/large_datasets/utils/hotword_utils.py b/funasr/datasets/large_datasets/utils/hotword_utils.py
deleted file mode 100644
index 66c131e..0000000
--- a/funasr/datasets/large_datasets/utils/hotword_utils.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import random
-
-
-def sample_hotword(
-    length,
-    hotword_min_length,
-    hotword_max_length,
-    sample_rate,
-    double_rate,
-    pre_prob,
-    pre_index=None,
-    pre_hwlist=None,
-):
-    if length < hotword_min_length:
-        return [-1]
-    if random.random() < sample_rate:
-        if pre_prob > 0 and random.random() < pre_prob and pre_index is not None:
-            return pre_index
-        if length == hotword_min_length:
-            return [0, length - 1]
-        elif random.random() < double_rate and length > hotword_max_length + hotword_min_length + 2:
-            # sample two hotwords in a sentence
-            _max_hw_length = min(hotword_max_length, length // 2)
-            # first hotword
-            start1 = random.randint(0, length // 3)
-            end1 = random.randint(start1 + hotword_min_length - 1, start1 + _max_hw_length - 1)
-            # second hotword
-            start2 = random.randint(end1 + 1, length - hotword_min_length)
-            end2 = random.randint(
-                min(length - 1, start2 + hotword_min_length - 1),
-                min(length - 1, start2 + hotword_max_length - 1),
-            )
-            return [start1, end1, start2, end2]
-        else:  # single hotword
-            start = random.randint(0, length - hotword_min_length)
-            end = random.randint(
-                min(length - 1, start + hotword_min_length - 1),
-                min(length - 1, start + hotword_max_length - 1),
-            )
-            return [start, end]
-    else:
-        return [-1]
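The removed sampler returned one of three index layouts, which the padding collator below consumed: [-1] for no hotword, [start, end] for a single inclusive span, and [s1, e1, s2, e2] for two spans in a long utterance. For example, forcing the sampling branch:

idxs = sample_hotword(length=20, hotword_min_length=2, hotword_max_length=5,
                      sample_rate=1.0, double_rate=0.0, pre_prob=0.0)
assert len(idxs) == 2  # sample_rate=1.0 with double_rate=0.0 forces one [start, end] span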
diff --git a/funasr/datasets/large_datasets/utils/low_frame_rate.py b/funasr/datasets/large_datasets/utils/low_frame_rate.py
deleted file mode 100644
index 87718e9..0000000
--- a/funasr/datasets/large_datasets/utils/low_frame_rate.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import numpy as np
-
-
-def build_LFR_features(data, m, n):
-    """
-    Actually, this implements stacking frames and skipping frames.
-    if m = 1 and n = 1, just return the origin features.
-    if m = 1 and n > 1, it works like skipping.
-    if m > 1 and n = 1, it works like stacking but only support right frames.
-    if m > 1 and n > 1, it works like LFR.
-
-    Args:
-        inputs_batch: inputs is T x D np.ndarray
-        m: number of frames to stack
-        n: number of frames to skip
-    """
-
-    LFR_inputs = []
-    T = data.shape[0]
-    T_lfr = int(np.ceil(T / n))
-    for i in range(T_lfr):
-        if m <= T - i * n:
-            LFR_inputs.append(np.hstack(data[i * n : i * n + m]))
-        else:
-            num_padding = m - (T - i * n)
-            frame = np.hstack(data[i * n :])
-            for _ in range(num_padding):
-                frame = np.hstack((frame, data[-1]))
-            LFR_inputs.append(frame)
-    return np.vstack(LFR_inputs)
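A worked example of the removed LFR helper: with T = 5 input frames of dimension D = 2, stacking m = 3 and skipping n = 2 yields ceil(5/2) = 3 output frames of dimension m * D = 6, with the tail padded by copies of the last frame:

import numpy as np

data = np.arange(10, dtype=np.float32).reshape(5, 2)  # T=5, D=2
out = build_LFR_features(data, m=3, n=2)
assert out.shape == (3, 6)
# The last output frame starts at input frame 4 and repeats it twice as padding:
assert (out[2] == np.hstack([data[4], data[4], data[4]])).all()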
diff --git a/funasr/datasets/large_datasets/utils/padding.py b/funasr/datasets/large_datasets/utils/padding.py
deleted file mode 100644
index cb43a27..0000000
--- a/funasr/datasets/large_datasets/utils/padding.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import numpy as np
-import torch
-from torch.nn.utils.rnn import pad_sequence
-
-
-def padding(data, float_pad_value=0.0, int_pad_value=-1):
-    assert isinstance(data, list)
-    assert "key" in data[0]
-    assert "speech" in data[0] or "text" in data[0]
-
-    keys = [x["key"] for x in data]
-
-    batch = {}
-    data_names = data[0].keys()
-    for data_name in data_names:
-        if data_name == "key" or data_name == "sampling_rate":
-            continue
-        else:
-            if data_name != "hotword_indxs":
-                if data[0][data_name].dtype.kind == "i":
-                    pad_value = int_pad_value
-                    tensor_type = torch.int64
-                else:
-                    pad_value = float_pad_value
-                    tensor_type = torch.float32
-
-            tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
-            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
-            tensor_pad = pad_sequence(tensor_list, batch_first=True, padding_value=pad_value)
-            batch[data_name] = tensor_pad
-            batch[data_name + "_lengths"] = tensor_lengths
-
-    # SAC LABEL INCLUDE
-    if "hotword_indxs" in batch:
-        # if hotword indxs in batch
-        # use it to slice hotwords out
-        hotword_list = []
-        hotword_lengths = []
-        text = batch["text"]
-        text_lengths = batch["text_lengths"]
-        hotword_indxs = batch["hotword_indxs"]
-        dha_pad = torch.ones_like(text) * -1
-        _, t1 = text.shape
-        t1 += 1  # TODO: as parameter which is same as predictor_bias
-        nth_hw = 0
-        for b, (hotword_indx, one_text, length) in enumerate(
-            zip(hotword_indxs, text, text_lengths)
-        ):
-            dha_pad[b][:length] = 8405
-            if hotword_indx[0] != -1:
-                start, end = int(hotword_indx[0]), int(hotword_indx[1])
-                hotword = one_text[start : end + 1]
-                hotword_list.append(hotword)
-                hotword_lengths.append(end - start + 1)
-                dha_pad[b][start : end + 1] = one_text[start : end + 1]
-                nth_hw += 1
-                if len(hotword_indx) == 4 and hotword_indx[2] != -1:
-                    # the second hotword if exist
-                    start, end = int(hotword_indx[2]), int(hotword_indx[3])
-                    hotword_list.append(one_text[start : end + 1])
-                    hotword_lengths.append(end - start + 1)
-                    dha_pad[b][start : end + 1] = one_text[start : end + 1]
-                    nth_hw += 1
-        hotword_list.append(torch.tensor([1]))
-        hotword_lengths.append(1)
-        hotword_pad = pad_sequence(hotword_list, batch_first=True, padding_value=0)
-        batch["hotword_pad"] = hotword_pad
-        batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
-        batch["dha_pad"] = dha_pad
-        del batch["hotword_indxs"]
-        del batch["hotword_indxs_lengths"]
-    return keys, batch
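The non-hotword path of the removed collator is plain right-padding; its contract in miniature:

import numpy as np

a = {"key": "u1", "speech": np.zeros((4, 3), dtype=np.float32),
     "text": np.array([5, 6], dtype=np.int64)}
b = {"key": "u2", "speech": np.zeros((6, 3), dtype=np.float32),
     "text": np.array([7, 8, 9], dtype=np.int64)}
keys, batch = padding([a, b])
assert batch["speech"].shape == (2, 6, 3)        # floats padded with 0.0
assert batch["text"][0].tolist() == [5, 6, -1]   # ints padded with -1
assert batch["text_lengths"].tolist() == [2, 3]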
diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py
deleted file mode 100644
index 5a1ddd2..0000000
--- a/funasr/datasets/large_datasets/utils/tokenize.py
+++ /dev/null
@@ -1,93 +0,0 @@
-#!/usr/bin/env python
-import re
-import numpy as np
-from funasr.datasets.large_datasets.utils.hotword_utils import sample_hotword
-
-
-def forward_segment(text, seg_dict):
-    word_list = []
-    i = 0
-    while i < len(text):
-        longest_word = text[i]
-        for j in range(i + 1, len(text) + 1):
-            word = text[i:j]
-            if word in seg_dict:
-                if len(word) > len(longest_word):
-                    longest_word = word
-        word_list.append(longest_word)
-        i += len(longest_word)
-    return word_list
-
-
-def seg_tokenize(txt, seg_dict):
-    pattern = re.compile(r"^[\u4E00-\u9FA50-9]+$")
-    out_txt = ""
-    for word in txt:
-        word = word.lower()
-        if word in seg_dict:
-            out_txt += seg_dict[word] + " "
-        else:
-            if pattern.match(word):
-                for char in word:
-                    if char in seg_dict:
-                        out_txt += seg_dict[char] + " "
-                    else:
-                        out_txt += "<unk>" + " "
-            else:
-                out_txt += "<unk>" + " "
-    return out_txt.strip().split()
-
-
-def tokenize(data, vocab=None, seg_dict=None, punc_dict=None, bpe_tokenizer=None, hw_config=None):
-    assert "text" in data
-    assert isinstance(vocab, dict)
-    text = data["text"]
-    token = []
-    vad = -2
-    if bpe_tokenizer is not None:
-        text = bpe_tokenizer.text2tokens(" ".join(text))
-    if seg_dict is not None:
-        assert isinstance(seg_dict, dict)
-        text = seg_tokenize(text, seg_dict)
-
-    length = len(text)
-    if "hw_tag" in data:
-        pre_index = None
-        if hw_config["pre_hwlist"] is not None and hw_config["pre_prob"] > 0:
-            # enable preset hotword detect in sampling
-            for hw in hw_config["pre_hwlist"]:
-                hw = " ".join(seg_tokenize(hw, seg_dict))
-                _find = " ".join(text).find(hw)
-                if _find != -1:
-                    # _find = text[:_find].count(" ")  # bpe sometimes
-                    pre_index = [_find, _find + max(hw.count(" "), 1)]
-                    break
-        hotword_indxs = sample_hotword(length, **hw_config, pre_index=pre_index)
-        data["hotword_indxs"] = hotword_indxs
-        del data["hw_tag"]
-    for i in range(length):
-        x = text[i]
-        if i == length - 1 and "punc" in data and x.startswith("vad:"):
-            vad = x[4:]
-            if len(vad) == 0:
-                vad = -1
-            else:
-                vad = int(vad)
-        elif x in vocab:
-            token.append(vocab[x])
-        else:
-            token.append(vocab["<unk>"])
-
-    if "punc" in data and punc_dict is not None:
-        punc_token = []
-        for punc in data["punc"]:
-            if punc in punc_dict:
-                punc_token.append(punc_dict[punc])
-            else:
-                punc_token.append(punc_dict["_"])
-        data["punc"] = np.array(punc_token)
-
-    data["text"] = np.array(token)
-    if vad is not -2:
-        data["vad_indexes"] = np.array([vad], dtype=np.int64)
-    return data
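forward_segment above is forward maximum matching: at each position it commits to the longest dictionary entry starting there, falling back to a single character. For example:

seg_dict = {"ab": "ab", "abc": "abc", "cd": "cd"}
# Position 0 matches "abc" (the longest entry), then "d" falls through as a single char:
assert forward_segment("abcd", seg_dict) == ["abc", "d"]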
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index f1c01bd..df4f33d 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -85,8 +85,10 @@
 
         install_requirements(requirements)
     if kwargs.get("trust_remote_code", False):
+        from funasr.utils.dynamic_import import import_module_from_path
 
-        import model
+        model_code = kwargs.get("remote_code", "model")
+        import_module_from_path(model_code)
 
         # from funasr.register import tables
         # tables.print("model")
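With this change, trusted remote code no longer has to be importable as a literal model.py on sys.path; the caller names the file via remote_code (a path or URL). A typical invocation, with placeholder paths:

from funasr import AutoModel

model = AutoModel(
    model="path/or/hub-id/of/custom-model",  # placeholder
    trust_remote_code=True,
    remote_code="./model.py",  # local path (or URL) to the model definition
)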
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 43c044e..b4d9e7c 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -1145,6 +1145,7 @@
             fake_token_len_i = 0
             fbank_beg_i = -1
             fbank_lens_i = []
+            speech, speech_lengths = [], []
             for k, sub_str in enumerate(splits):
                 if not sub_str.startswith("<|startofspeech|>"):
                     sub_token = tokenizer.encode(sub_str)
@@ -1155,9 +1156,12 @@
                         "<|endofspeech|>", ""
                     )
                     if sub_str.startswith("!"):
+                        sub_str = sub_str[1:]
+                        if sub_str.startswith("!"):  # a second "!" marks inline bytes
+                            sub_str = eval(sub_str[1:])  # parse the repr'd bytes literal back into bytes
                         try:
                             time1 = time.perf_counter()
-                            data_src = load_audio_text_image_video(sub_str[1:], fs=frontend.fs)
+                            data_src = load_audio_text_image_video(sub_str, fs=frontend.fs)
                             time2 = time.perf_counter()
                             meta_data["load_data"] = f"{time2 - time1:0.3f}"
                         except Exception as e:
@@ -1203,9 +1207,10 @@
             input_source_ids = input_ids + source_ids
             input_ids += source_ids + target_ids
             labels += source_mask + target_ids
-            fbank.append(speech[0, :, :])
             fbank_mask += fbank_mask_i
-            fbank_lens.append(speech_lengths)
+            if len(speech) > 0:
+                fbank.append(speech[0, :, :])
+                fbank_lens.append(speech_lengths)
 
         input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
         attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
@@ -1219,10 +1224,14 @@
         source_ids = torch.tensor(input_source_ids, dtype=torch.int64)
         target_ids = torch.tensor(target_ids, dtype=torch.int64)
 
-        speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
-        speech_lengths = torch.nn.utils.rnn.pad_sequence(
-            fbank_lens, batch_first=True, padding_value=-1
-        )
+        if len(fbank) > 0:
+            speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
+            speech_lengths = torch.nn.utils.rnn.pad_sequence(
+                fbank_lens, batch_first=True, padding_value=-1
+            )
+        else:
+            speech = []
+            speech_lengths = []
         output = {
             "speech": speech,
             "speech_lengths": speech_lengths,
@@ -1238,7 +1247,8 @@
 
         return output
 
-    def inference(
+
+    def inference_prepare(
         self,
         data_in,
         data_lengths=None,
@@ -1260,17 +1270,18 @@
 
         # audio encoder
         speech = batch["speech"]
-        speech_lengths = batch["speech_lengths"][:, 0]
-        # fp16
-        if kwargs.get("fp16", False):
-            speech = speech.to(torch.float16)
-        elif kwargs.get("bf16", False):
-            speech = speech.to(torch.bfloat16)
-        # audio encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if len(speech) > 0:
+            speech_lengths = batch["speech_lengths"][:, 0]
+            # fp16
+            if kwargs.get("fp16", False):
+                speech = speech.to(torch.float16)
+            elif kwargs.get("bf16", False):
+                speech = speech.to(torch.bfloat16)
+            # audio encoder
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
-        # audio_adaptor
-        encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
+            # audio_adaptor
+            encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
 
         input_ids = batch["input_ids"]
         source_ids = batch["source_ids"]
@@ -1316,6 +1327,22 @@
                         ] = speech_token
 
                     speech_idx += 1
+        return inputs_embeds, contents, batch, source_ids, meta_data
+
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+
+        inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
+            data_in, data_lengths, key, tokenizer, frontend, **kwargs
+        )
 
         llm_dtype = kwargs.get("llm_dtype", "fp32")
         if llm_dtype == "fp32":
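Splitting the old inference into inference_prepare plus a decoding tail lets callers reuse the embedding preparation with their own generation loop, for example a token-streaming one. A hedged sketch, assuming self.llm is a Hugging Face causal LM and tokenizer is its tokenizer (the streamer wiring is illustrative, not part of this patch):

from threading import Thread
from transformers import TextIteratorStreamer

inputs_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
    data_in, None, key, tokenizer, frontend, **kwargs
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
Thread(target=model.llm.generate,
       kwargs=dict(inputs_embeds=inputs_embeds, streamer=streamer,
                   max_new_tokens=512)).start()
for piece in streamer:
    print(piece, end="", flush=True)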
diff --git a/funasr/utils/dynamic_import.py b/funasr/utils/dynamic_import.py
index 71ad4fe..531b96b 100644
--- a/funasr/utils/dynamic_import.py
+++ b/funasr/utils/dynamic_import.py
@@ -2,6 +2,8 @@
 
 import importlib.util
 import inspect
+import os.path
+import sys
 
 
 def load_module_from_path(file_path):
@@ -18,6 +20,23 @@
     return module
 
 
+def import_module_from_path(file_path: str):
+
+    if file_path.startswith("http"):
+        from funasr.download.file import download_from_url
+
+        file_path = download_from_url(file_path)
+
+    file_dir = os.path.dirname(file_path)
+    file_name = os.path.basename(file_path)
+    module_name = os.path.splitext(file_name)[0]
+    if len(file_dir) < 1:
+        file_dir = "./"
+    sys.path.append(file_dir)
+    importlib.import_module(module_name)
+    print(f"Loaded remote code successfully: {file_path}")
+
+
 #
 # def load_module_from_path(module_name, file_path):
 #     """
diff --git a/funasr/utils/version_checker.py b/funasr/utils/version_checker.py
index b89af48..7597530 100644
--- a/funasr/utils/version_checker.py
+++ b/funasr/utils/version_checker.py
@@ -1,9 +1,10 @@
-import requests
 from packaging import version
 from funasr import __version__  # Ensure that __version__ is defined in your package's __init__.py
 
 
 def get_pypi_version(package_name):
+    import requests
+
     url = f"https://pypi.org/pypi/{package_name}/json"
     response = requests.get(url)
     if response.status_code == 200:
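Deferring import requests into get_pypi_version keeps funasr importable on machines without requests installed; the cost only surfaces when the version check actually runs. If the check must never raise (offline boxes, missing dependency), the same idea extends to a guarded variant, sketched here as a hypothetical helper:

def get_pypi_version_safe(package_name):
    # Hypothetical wrapper: returns None instead of raising when offline
    # or when requests is absent.
    try:
        import requests
        resp = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
        return resp.json()["info"]["version"] if resp.status_code == 200 else None
    except Exception:
        return None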
diff --git a/setup.py b/setup.py
index 3b40f03..a26ae74 100644
--- a/setup.py
+++ b/setup.py
@@ -40,7 +40,7 @@
         "hydra-core>=1.3.2",
         "tensorboardX",
         # "rotary_embedding_torch",
-        "openai-whisper",
+        "requests",
     ],
     # train: The modules invoked when training only.
     "train": [

--
Gitblit v1.9.1