From a0f03bd2a87d97d47a1636bbe6f0855a43160331 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 15 五月 2024 19:48:50 +0800
Subject: [PATCH] Dev gzf deepspeed (#1732)

---
 funasr/datasets/audio_datasets/espnet_samplers.py             |    7 
 funasr/train_utils/trainer.py                                 |   50 +
 funasr/bin/train_ds.py                                        |  241 ++++++++
 README_zh.md                                                  |   28 
 examples/industrial_data_pretraining/sense_voice/demo.py      |    5 
 README.md                                                     |   10 
 funasr/datasets/sense_voice_datasets/datasets.py              |   24 
 funasr/models/sense_voice/model.py                            |  412 +++++++++++++++
 funasr/utils/misc.py                                          |   14 
 funasr/tokenizer/sentencepiece_tokenizer.py                   |    8 
 examples/industrial_data_pretraining/emotion2vec/demo.py      |    8 
 funasr/bin/train.py                                           |    3 
 funasr/train_utils/trainer_ds.py                              |  800 +++++++++++++++++++++++++++++
 examples/industrial_data_pretraining/sense_voice/demo_fsmn.py |    1 
 14 files changed, 1,553 insertions(+), 58 deletions(-)

diff --git a/README.md b/README.md
index ba23f3f..e02b3e2 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,7 @@
 
 <a name="whats-new"></a>
 ## What's new:
+- 2024/05/15锛歟motion recognition models are new supported. [emotion2vec+large](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary)锛孾emotion2vec+base](https://modelscope.cn/models/iic/emotion2vec_plus_base/summary)锛孾emotion2vec+seed](https://modelscope.cn/models/iic/emotion2vec_plus_seed/summary). currently supports the following categories: 0: angry 1: happy 2: neutral 3: sad 4: unknown.
 - 2024/05/15: Offline File Transcription Service 4.5, Offline File Transcription Service of English 1.6锛孯eal-time Transcription Service 1.10 released锛宎dapting to FunASR 1.0 model structure锛�([docs](runtime/readme.md))
 - 2024/03/05锛欰dded the Qwen-Audio and Qwen-Audio-Chat large-scale audio-text multimodal models, which have topped multiple audio domain leaderboards. These models support speech dialogue, [usage](examples/industrial_data_pretraining/qwen_audio).
 - 2024/03/05锛欰dded support for the Whisper-large-v3 model, a multitasking model that can perform multilingual speech recognition, speech translation, and language identification. It can be downloaded from the[modelscope](examples/industrial_data_pretraining/whisper/demo.py), and [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py).
@@ -84,10 +85,11 @@
 |                                   fsmn-vad <br> ( [猸怾(https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [馃](https://huggingface.co/funasr/fsmn-vad) )                                   |               voice activity detection                | 5000 hours, Mandarin and English |    0.4M    | 
 |                                     fa-zh <br> ( [猸怾(https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [馃](https://huggingface.co/funasr/fa-zh) )                                     |                 timestamp prediction                  |       5000 hours, Mandarin       |    38M     | 
 |                                       cam++ <br> ( [猸怾(https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [馃](https://huggingface.co/funasr/campplus) )                                        |           speaker verification/diarization            |            5000 hours            |    7.2M    | 
-|                                                  Whisper-large-v2 <br> ([猸怾(https://www.modelscope.cn/models/iic/speech_whisper-large_asr_multilingual/summary)  [馃崁](https://github.com/openai/whisper) )                                                  |  speech recognition, with timestamps, non-streaming   |          multilingual      |    1550 M    |
-|                                                Whisper-large-v3 <br> ([猸怾(https://www.modelscope.cn/models/iic/Whisper-large-v3/summary)  [馃崁](https://github.com/openai/whisper) )                                                 |  speech recognition, with timestamps, non-streaming   |          multilingual            |    1550 M    |
-|                                         Qwen-Audio <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo.py)  [馃](https://huggingface.co/Qwen/Qwen-Audio) )                                         |      audio-text multimodal models (pretraining)       |     multilingual      |  8B  |
-|                   Qwen-Audio-Chat <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo_chat.py)  [馃](https://huggingface.co/Qwen/Qwen-Audio-Chat) )                                                |          audio-text multimodal models (chat)          |     multilingual      |  8B  |
+|                                                  Whisper-large-v2 <br> ([猸怾(https://www.modelscope.cn/models/iic/speech_whisper-large_asr_multilingual/summary)  [馃崁](https://github.com/openai/whisper) )                                                  |  speech recognition, with timestamps, non-streaming   |           multilingual           |    1550 M    |
+|                                                Whisper-large-v3 <br> ([猸怾(https://www.modelscope.cn/models/iic/Whisper-large-v3/summary)  [馃崁](https://github.com/openai/whisper) )                                                 |  speech recognition, with timestamps, non-streaming   |           multilingual           |    1550 M    |
+|                                         Qwen-Audio <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo.py)  [馃](https://huggingface.co/Qwen/Qwen-Audio) )                                         |      audio-text multimodal models (pretraining)       |           multilingual           |  8B  |
+|                   Qwen-Audio-Chat <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo_chat.py)  [馃](https://huggingface.co/Qwen/Qwen-Audio-Chat) )                                                |          audio-text multimodal models (chat)          |           multilingual           |  8B  |
+|                        emotion2vec+large <br> ([猸怾(https://modelscope.cn/models/iic/emotion2vec_plus_large/summary)  [馃](https://huggingface.co/emotion2vec/emotion2vec_plus_large) )                        |              speech emotion recongintion              |           40000 hours            |  300M  |
 
 
 
diff --git a/README_zh.md b/README_zh.md
index 44f92e6..d3c9ff6 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -29,6 +29,7 @@
 
 <a name="鏈�鏂板姩鎬�"></a>
 ## 鏈�鏂板姩鎬�
+- 2024/05/15锛氭柊澧炲姞鎯呮劅璇嗗埆妯″瀷锛孾emotion2vec+large](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary)锛孾emotion2vec+base](https://modelscope.cn/models/iic/emotion2vec_plus_base/summary)锛孾emotion2vec+seed](https://modelscope.cn/models/iic/emotion2vec_plus_seed/summary)锛岃緭鍑烘儏鎰熺被鍒负锛氱敓姘�/angry锛屽紑蹇�/happy锛屼腑绔�/neutral锛岄毦杩�/sad銆�
 - 2024/05/15: 涓枃绂荤嚎鏂囦欢杞啓鏈嶅姟 4.5銆佽嫳鏂囩绾挎枃浠惰浆鍐欐湇鍔� 1.6銆佷腑鏂囧疄鏃惰闊冲惉鍐欐湇鍔� 1.10 鍙戝竷锛岄�傞厤FunASR 1.0妯″瀷缁撴瀯锛涜缁嗕俊鎭弬闃�([閮ㄧ讲鏂囨。](runtime/readme_cn.md))
 - 2024/03/05锛氭柊澧炲姞Qwen-Audio涓嶲wen-Audio-Chat闊抽鏂囨湰妯℃�佸ぇ妯″瀷锛屽湪澶氫釜闊抽棰嗗煙娴嬭瘯姒滃崟鍒锋锛屼腑鏀寔璇煶瀵硅瘽锛岃缁嗙敤娉曡 [绀轰緥](examples/industrial_data_pretraining/qwen_audio)銆�
 - 2024/03/05锛氭柊澧炲姞Whisper-large-v3妯″瀷鏀寔锛屽璇█璇煶璇嗗埆/缈昏瘧/璇璇嗗埆锛屾敮鎸佷粠 [modelscope](examples/industrial_data_pretraining/whisper/demo.py)浠撳簱涓嬭浇锛屼篃鏀寔浠� [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py)浠撳簱涓嬭浇妯″瀷銆�
@@ -75,19 +76,20 @@
 锛堟敞锛氣瓙 琛ㄧずModelScope妯″瀷浠撳簱锛岎煠� 琛ㄧずHuggingface妯″瀷浠撳簱锛岎煃�琛ㄧずOpenAI妯″瀷浠撳簱锛�
 
 
-|                                                                                                     妯″瀷鍚嶅瓧                                                                                                      |        浠诲姟璇︽儏        |     璁粌鏁版嵁     | 鍙傛暟閲�  | 
-|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------:|:------------:|:----:|
-|    paraformer-zh <br> ([猸怾(https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)  [馃](https://huggingface.co/funasr/paraformer-tp) )    |  璇煶璇嗗埆锛屽甫鏃堕棿鎴宠緭鍑猴紝闈炲疄鏃�   |  60000灏忔椂锛屼腑鏂�  | 220M |
-| paraformer-zh-streaming <br> ( [猸怾(https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [馃](https://huggingface.co/funasr/paraformer-zh-streaming) ) |      璇煶璇嗗埆锛屽疄鏃�       |  60000灏忔椂锛屼腑鏂�  | 220M |
-|         paraformer-en <br> ( [猸怾(https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [馃](https://huggingface.co/funasr/paraformer-en) )         |      璇煶璇嗗埆锛岄潪瀹炴椂      |  50000灏忔椂锛岃嫳鏂�  | 220M |
-|                      conformer-en <br> ( [猸怾(https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [馃](https://huggingface.co/funasr/conformer-en) )                      |      璇煶璇嗗埆锛岄潪瀹炴椂      |  50000灏忔椂锛岃嫳鏂�  | 220M |
-|                        ct-punc <br> ( [猸怾(https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [馃](https://huggingface.co/funasr/ct-punc) )                         |        鏍囩偣鎭㈠        |  100M锛屼腑鏂囦笌鑻辨枃  | 1.1B | 
-|                            fsmn-vad <br> ( [猸怾(https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [馃](https://huggingface.co/funasr/fsmn-vad) )                             |     璇煶绔偣妫�娴嬶紝瀹炴椂      | 5000灏忔椂锛屼腑鏂囦笌鑻辨枃 | 0.4M | 
-|                              fa-zh <br> ( [猸怾(https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [馃](https://huggingface.co/funasr/fa-zh) )                               |      瀛楃骇鍒椂闂存埑棰勬祴      |  50000灏忔椂锛屼腑鏂�  | 38M  |
-|                                 cam++ <br> ( [猸怾(https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [馃](https://huggingface.co/funasr/campplus) )                                 |      璇磋瘽浜虹‘璁�/鍒嗗壊      |    5000灏忔椂    | 7.2M | 
-|                                     Whisper-large-v3 <br> ([猸怾(https://www.modelscope.cn/models/iic/Whisper-large-v3/summary)  [馃崁](https://github.com/openai/whisper) )                                      |  璇煶璇嗗埆锛屽甫鏃堕棿鎴宠緭鍑猴紝闈炲疄鏃�   |     澶氳瑷�      |  1550 M  |
-|                                         Qwen-Audio <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo.py)  [馃](https://huggingface.co/Qwen/Qwen-Audio) )                                         |  闊抽鏂囨湰澶氭ā鎬佸ぇ妯″瀷锛堥璁粌锛�   |     澶氳瑷�      |  8B  |
-|                   Qwen-Audio-Chat <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo_chat.py)  [馃](https://huggingface.co/Qwen/Qwen-Audio-Chat) )                                                | 闊抽鏂囨湰澶氭ā鎬佸ぇ妯″瀷锛坈hat鐗堟湰锛� |     澶氳瑷�      |  8B  |
+|                                                                                                     妯″瀷鍚嶅瓧                                                                                                      |        浠诲姟璇︽儏        |      璁粌鏁版嵁      |  鍙傛暟閲�   | 
+|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------:|:--------------:|:------:|
+|    paraformer-zh <br> ([猸怾(https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)  [馃](https://huggingface.co/funasr/paraformer-tp) )    |  璇煶璇嗗埆锛屽甫鏃堕棿鎴宠緭鍑猴紝闈炲疄鏃�   |   60000灏忔椂锛屼腑鏂�   |  220M  |
+| paraformer-zh-streaming <br> ( [猸怾(https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [馃](https://huggingface.co/funasr/paraformer-zh-streaming) ) |      璇煶璇嗗埆锛屽疄鏃�       |   60000灏忔椂锛屼腑鏂�   |  220M  |
+|         paraformer-en <br> ( [猸怾(https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [馃](https://huggingface.co/funasr/paraformer-en) )         |      璇煶璇嗗埆锛岄潪瀹炴椂      |   50000灏忔椂锛岃嫳鏂�   |  220M  |
+|                      conformer-en <br> ( [猸怾(https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [馃](https://huggingface.co/funasr/conformer-en) )                      |      璇煶璇嗗埆锛岄潪瀹炴椂      |   50000灏忔椂锛岃嫳鏂�   |  220M  |
+|                        ct-punc <br> ( [猸怾(https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [馃](https://huggingface.co/funasr/ct-punc) )                         |        鏍囩偣鎭㈠        |   100M锛屼腑鏂囦笌鑻辨枃   |  1.1B  | 
+|                            fsmn-vad <br> ( [猸怾(https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [馃](https://huggingface.co/funasr/fsmn-vad) )                             |     璇煶绔偣妫�娴嬶紝瀹炴椂      |  5000灏忔椂锛屼腑鏂囦笌鑻辨枃  |  0.4M  | 
+|                              fa-zh <br> ( [猸怾(https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [馃](https://huggingface.co/funasr/fa-zh) )                               |      瀛楃骇鍒椂闂存埑棰勬祴      |   50000灏忔椂锛屼腑鏂�   |  38M   |
+|                                 cam++ <br> ( [猸怾(https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [馃](https://huggingface.co/funasr/campplus) )                                 |      璇磋瘽浜虹‘璁�/鍒嗗壊      |     5000灏忔椂     |  7.2M  | 
+|                                     Whisper-large-v3 <br> ([猸怾(https://www.modelscope.cn/models/iic/Whisper-large-v3/summary)  [馃崁](https://github.com/openai/whisper) )                                      |  璇煶璇嗗埆锛屽甫鏃堕棿鎴宠緭鍑猴紝闈炲疄鏃�   |      澶氳瑷�       | 1550 M |
+|                                         Qwen-Audio <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo.py)  [馃](https://huggingface.co/Qwen/Qwen-Audio) )                                         |  闊抽鏂囨湰澶氭ā鎬佸ぇ妯″瀷锛堥璁粌锛�   |      澶氳瑷�       |   8B   |
+|                                 Qwen-Audio-Chat <br> ([猸怾(examples/industrial_data_pretraining/qwen_audio/demo_chat.py)  [馃](https://huggingface.co/Qwen/Qwen-Audio-Chat) )                                  | 闊抽鏂囨湰澶氭ā鎬佸ぇ妯″瀷锛坈hat鐗堟湰锛� |      澶氳瑷�       |   8B   |
+|                        emotion2vec+large <br> ([猸怾(https://modelscope.cn/models/iic/emotion2vec_plus_large/summary)  [馃](https://huggingface.co/emotion2vec/emotion2vec_plus_large) )                        |    鎯呮劅璇嗗埆妯″瀷          | 40000灏忔椂锛�4绉嶆儏鎰熺被鍒� |  300M  |
 
 <a name="蹇�熷紑濮�"></a>
 ## 蹇�熷紑濮�
diff --git a/examples/industrial_data_pretraining/emotion2vec/demo.py b/examples/industrial_data_pretraining/emotion2vec/demo.py
index f33dfee..71f69bb 100644
--- a/examples/industrial_data_pretraining/emotion2vec/demo.py
+++ b/examples/industrial_data_pretraining/emotion2vec/demo.py
@@ -6,14 +6,20 @@
 from funasr import AutoModel
 
 # model="iic/emotion2vec_base"
+# model="iic/emotion2vec_base_finetuned"
+# model="iic/emotion2vec_plus_seed"
+# model="iic/emotion2vec_plus_base"
+model = "iic/emotion2vec_plus_large"
+
 model = AutoModel(
-    model="iic/emotion2vec_base_finetuned",
+    model=model,
     # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
     # vad_model_revision="master",
     # vad_kwargs={"max_single_segment_time": 2000},
 )
 
 wav_file = f"{model.model_path}/example/test.wav"
+
 res = model.generate(
     wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False
 )
diff --git a/examples/industrial_data_pretraining/sense_voice/demo.py b/examples/industrial_data_pretraining/sense_voice/demo.py
index ed583f0..5303999 100644
--- a/examples/industrial_data_pretraining/sense_voice/demo.py
+++ b/examples/industrial_data_pretraining/sense_voice/demo.py
@@ -7,8 +7,8 @@
 
 model = AutoModel(
     model="/Users/zhifu/Downloads/modelscope_models/SenseVoiceModelscope",
-    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-    vad_kwargs={"max_single_segment_time": 30000},
+    # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+    # vad_kwargs={"max_single_segment_time": 30000},
 )
 
 
@@ -21,6 +21,7 @@
     "language": "auto",
     "fp16": True,
     "gain_event": True,
+    "beam_size": 5,
 }
 
 res = model.generate(input=input_wav, batch_size_s=0, DecodingOptions=DecodingOptions)
diff --git a/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py b/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py
index e063e1f..ce4bdf8 100644
--- a/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py
+++ b/examples/industrial_data_pretraining/sense_voice/demo_fsmn.py
@@ -21,6 +21,7 @@
     "language": "auto",
     "fp16": True,
     "gain_event": True,
+    "beam_size": 5,
 }
 
 res = model.generate(input=input_wav, batch_size_s=0, DecodingOptions=DecodingOptions, beam_size=5)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 7695e51..c3556d1 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -223,6 +223,7 @@
 
             torch.cuda.empty_cache()
 
+        trainer.start_data_split_i = 0
         trainer.validate_epoch(
             model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
         )
@@ -240,6 +241,8 @@
             f"estimated to finish {trainer.max_epoch} "
             f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n"
         )
+        trainer.train_acc_avg = 0.0
+        trainer.train_loss_avg = 0.0
 
     if trainer.rank == 0:
         average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
diff --git a/funasr/bin/train_ds.py b/funasr/bin/train_ds.py
new file mode 100644
index 0000000..e4db533
--- /dev/null
+++ b/funasr/bin/train_ds.py
@@ -0,0 +1,241 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+
+import os
+import sys
+import torch
+import torch.nn as nn
+import hydra
+import logging
+import time
+import argparse
+from io import BytesIO
+
+from contextlib import nullcontext
+import torch.distributed as dist
+
+from omegaconf import DictConfig, OmegaConf
+from torch.cuda.amp import autocast, GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.algorithms.join import Join
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from funasr.train_utils.average_nbest_models import average_checkpoints
+
+from funasr.register import tables
+from funasr.optimizers import optim_classes
+from funasr.train_utils.trainer_ds import Trainer
+from funasr.schedulers import scheduler_classes
+from funasr.train_utils.initialize import initialize
+from funasr.download.download_from_hub import download_model
+from funasr.models.lora.utils import mark_only_lora_as_trainable
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
+from funasr.train_utils.load_pretrained_model import load_pretrained_model
+from funasr.utils.misc import prepare_model_dir
+from funasr.train_utils.model_summary import model_summary
+from funasr import AutoModel
+
+try:
+    import deepspeed
+except:
+    deepspeed = None
+
+
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(kwargs: DictConfig):
+    if kwargs.get("debug", False):
+        import pdb
+
+        pdb.set_trace()
+
+    assert "model" in kwargs
+    if "model_conf" not in kwargs:
+        logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
+        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
+
+    main(**kwargs)
+
+
+def main(**kwargs):
+
+    # set random seed
+    set_all_random_seed(kwargs.get("seed", 0))
+    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
+    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
+    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
+    # open tf32
+    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
+
+    rank = int(os.environ.get("RANK", 0))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+
+    if local_rank == 0:
+        tables.print()
+
+    use_ddp = world_size > 1
+    use_fsdp = kwargs.get("use_fsdp", False)
+    use_deepspeed = kwargs.get("use_deepspeed", False)
+    if use_deepspeed:
+        logging.info(f"use_deepspeed: {use_deepspeed}")
+        deepspeed.init_distributed(dist_backend=kwargs.get("backend", "nccl"))
+    elif use_ddp or use_fsdp:
+        logging.info(f"use_ddp: {use_ddp}, use_fsdp: {use_fsdp}")
+        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
+        torch.cuda.set_device(local_rank)
+
+    logging.info("Build model, frontend, tokenizer")
+    device = kwargs.get("device", "cuda")
+    kwargs["device"] = "cpu"
+    model = AutoModel(**kwargs)
+
+    # save config.yaml
+    if rank == 0:
+        prepare_model_dir(**kwargs)
+
+    # parse kwargs
+    kwargs = model.kwargs
+    kwargs["device"] = device
+    tokenizer = kwargs["tokenizer"]
+    frontend = kwargs["frontend"]
+    model = model.model
+    del kwargs["model"]
+
+    # freeze_param
+    freeze_param = kwargs.get("freeze_param", None)
+    if freeze_param is not None:
+        if "," in freeze_param:
+            freeze_param = eval(freeze_param)
+        if not isinstance(freeze_param, (list, tuple)):
+            freeze_param = (freeze_param,)
+        logging.info("freeze_param is not None: %s", freeze_param)
+        for t in freeze_param:
+            for k, p in model.named_parameters():
+                if k.startswith(t + ".") or k == t:
+                    logging.info(f"Setting {k}.requires_grad = False")
+                    p.requires_grad = False
+    if local_rank == 0:
+        logging.info(f"{model_summary(model)}")
+
+    trainer = Trainer(
+        rank=rank,
+        local_rank=local_rank,
+        world_size=world_size,
+        use_ddp=use_ddp,
+        use_fsdp=use_fsdp,
+        device=kwargs["device"],
+        output_dir=kwargs.get("output_dir", "./exp"),
+        **kwargs.get("train_conf"),
+    )
+
+    model = trainer.warp_model(model)
+
+    kwargs["device"] = next(model.parameters()).device
+    trainer.device = kwargs["device"]
+
+    # optim
+    logging.info("Build optim")
+    optim = kwargs.get("optim", "adam")
+    assert optim in optim_classes
+    optim_class = optim_classes.get(optim)
+    optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
+
+    # scheduler
+    logging.info("Build scheduler")
+    scheduler = kwargs.get("scheduler", "warmuplr")
+    assert scheduler in scheduler_classes
+    scheduler_class = scheduler_classes.get(scheduler)
+    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
+
+    if use_deepspeed:
+        args = OmegaConf.create({"deepspeed_config": kwargs.get("deepspeed_config", "")})
+        model, optimizer, _, scheduler = deepspeed.initialize(
+            args=args,
+            model=model,
+            optimizer=optim,
+            lr_scheduler=scheduler,
+            model_parameters=model.parameters(),
+        )
+
+    # dataset
+    logging.info("Build dataloader")
+    dataloader_class = tables.dataloader_classes.get(
+        kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle")
+    )
+    dataloader = dataloader_class(**kwargs)
+    # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
+
+    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
+    scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
+
+    trainer.resume_checkpoint(
+        model=model,
+        optim=optim,
+        scheduler=scheduler,
+        scaler=scaler,
+    )
+
+    tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
+    os.makedirs(tensorboard_dir, exist_ok=True)
+    try:
+        from tensorboardX import SummaryWriter
+
+        writer = SummaryWriter(tensorboard_dir)  # if trainer.rank == 0 else None
+    except:
+        writer = None
+
+    dataloader_tr, dataloader_val = None, None
+    for epoch in range(trainer.start_epoch, trainer.max_epoch):
+        time1 = time.perf_counter()
+
+        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
+            dataloader_tr, dataloader_val = dataloader.build_iter(
+                epoch, data_split_i=data_split_i, start_step=trainer.start_step
+            )
+
+            trainer.train_epoch(
+                model=model,
+                optim=optim,
+                scheduler=scheduler,
+                scaler=scaler,
+                dataloader_train=dataloader_tr,
+                dataloader_val=dataloader_val,
+                epoch=epoch,
+                writer=writer,
+                data_split_i=data_split_i,
+                data_split_num=dataloader.data_split_num,
+                start_step=trainer.start_step,
+            )
+            trainer.start_step = 0
+
+            torch.cuda.empty_cache()
+
+        trainer.start_data_split_i = 0
+        trainer.validate_epoch(
+            model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
+        )
+        scheduler.step()
+        trainer.step_in_epoch = 0
+        trainer.save_checkpoint(
+            epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
+        )
+
+        time2 = time.perf_counter()
+        time_escaped = (time2 - time1) / 3600.0
+        logging.info(
+            f"rank: {local_rank}, "
+            f"time_escaped_epoch: {time_escaped:.3f} hours, "
+            f"estimated to finish {trainer.max_epoch} "
+            f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n"
+        )
+        trainer.train_acc_avg = 0.0
+        trainer.train_loss_avg = 0.0
+
+    if trainer.rank == 0:
+        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
+
+    trainer.close()
+
+
+if __name__ == "__main__":
+    main_hydra()
diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index e155cd7..528f593 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -146,10 +146,9 @@
         start_idx = self.rank * batches_per_rank
         end_idx = start_idx + batches_per_rank
         rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
-        if self.start_step > 0:
-            logging.info(
-                f"Warning, rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num_before: {end_idx-start_idx}, now: {len(rank_batches)}"
-            )
+        logging.info(
+            f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {end_idx-start_idx}, batch_num_after_step: {len(rank_batches)}"
+        )
         # Return an iterator over the batches for the current rank
         return iter(rank_batches)
 
diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
index ee2f13d..690a1c5 100644
--- a/funasr/datasets/sense_voice_datasets/datasets.py
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -53,6 +53,12 @@
         self.prompt_ids_len = 0
         self.retry = kwargs.get("retry", 5)
 
+        self.permute = False
+        from funasr.frontends.whisper_frontend import WhisperFrontend
+
+        if isinstance(self.frontend, WhisperFrontend):
+            self.permute = True
+
     def get_source_len(self, index):
         item = self.index_ds[index]
         return self.index_ds.get_source_len(item)
@@ -92,7 +98,8 @@
 
             if speech_lengths > self.batch_size:
                 continue
-            speech = speech.permute(0, 2, 1)
+            if self.permute:
+                speech = speech.permute(0, 2, 1)
             target = item["target"]
             if self.preprocessor_text:
                 target = self.preprocessor_text(target)
@@ -100,8 +107,14 @@
             task = item.get("prompt", "<|ASR|>")
             text_language = item.get("text_language", "<|zh|>")
 
-            prompt = f"{self.sos}{task}{text_language}"
-            prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
+            if isinstance(self.sos, str):
+                prompt = f"{self.sos}{task}{text_language}"
+                prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
+            else:
+                prompt = f"{task}{text_language}"
+                prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
+                prompt_ids = [self.sos] + prompt_ids
+
             prompt_ids_len = len(prompt_ids) - 1  # [sos, task]
             self.prompt_ids_len = prompt_ids_len
 
@@ -110,7 +123,10 @@
             if target_ids_len > 200:
                 continue
 
-            eos = self.tokenizer.encode(self.eos, allowed_special="all")  # [eos]
+            if isinstance(self.eos, str):
+                eos = self.tokenizer.encode(self.eos, allowed_special="all")  # [eos]
+            else:
+                eos = [self.eos]
 
             ids = prompt_ids + target_ids + eos  # [sos, task, lid, text, eos]
             ids_lengths = len(ids)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 56e61e7..127d5a0 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -966,3 +966,415 @@
                     ibest_writer["text"][key[i]] = text
 
         return results, meta_data
+
+
+@tables.register("model_classes", "SenseVoiceSANM")
+class SenseVoiceSANM(nn.Module):
+
+    def __init__(
+        self,
+        specaug: str = None,
+        specaug_conf: dict = None,
+        normalize: str = None,
+        normalize_conf: dict = None,
+        encoder: str = None,
+        encoder_conf: dict = None,
+        decoder: str = None,
+        decoder_conf: dict = None,
+        input_size: int = 80,
+        vocab_size: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        # extract_feats_in_collect_stats: bool = True,
+        share_embedding: bool = False,
+        # preencoder: Optional[AbsPreEncoder] = None,
+        # postencoder: Optional[AbsPostEncoder] = None,
+        **kwargs,
+    ):
+
+        super().__init__()
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(input_size=input_size, **encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        decoder_class = tables.decoder_classes.get(decoder)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            **decoder_conf,
+        )
+
+        self.blank_id = blank_id
+        self.sos = sos if sos is not None else vocab_size - 1
+        self.eos = eos if eos is not None else vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+
+        self.specaug = specaug
+
+        self.encoder = encoder
+
+        self.decoder = decoder
+
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        self.error_calculator = None
+
+        self.length_normalized_loss = length_normalized_loss
+        self.beam_search = None
+        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ):
+        target_mask = kwargs.get("target_mask", None)
+
+        # import pdb;
+        # pdb.set_trace()
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+
+        batch_size, frames, _ = speech.shape
+        _, text_tokens = text.shape
+
+        if self.activation_checkpoint:
+            from torch.utils.checkpoint import checkpoint
+
+            encoder_out, encoder_out_lens = checkpoint(
+                self.encode, speech, speech_lengths, use_reentrant=False
+            )
+        else:
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
+        )
+
+        loss = loss_att
+        stats = {}
+        stats["acc"] = acc_att
+        stats["loss"] = torch.clone(loss.detach())
+        stats["batch_size"] = batch_size
+        stats["batch_size_x_frames"] = frames * batch_size
+        stats["batch_size_real_frames"] = speech_lengths.sum().item()
+        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
+        stats["batch_size_x_tokens"] = text_tokens * batch_size
+        stats["batch_size_real_tokens"] = text_lengths.sum().item()
+        stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
+        stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + 1).sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        **kwargs,
+    ):
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                ind: int
+        """
+        with autocast(False):
+
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+        # Forward encoder
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+
+        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+        if isinstance(encoder_out, (tuple, list)):
+            encoder_out = encoder_out[0]
+
+        return encoder_out, encoder_out_lens
+
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+        **kwargs,
+    ):
+        target_mask = kwargs.get("target_mask", None)
+        stats = {}
+
+        # 1. Forward decoder
+        ys_pad[ys_pad == -1] = 0
+        decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+        if isinstance(decoder_out, (list, tuple)):
+            decoder_out = decoder_out[0]
+
+        # 2. Compute attention loss
+        mask = torch.ones_like(ys_pad) * (-1)
+        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
+        ys_pad_mask[ys_pad_mask == 0] = -1
+        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
+
+        with torch.no_grad():
+            preds = torch.argmax(decoder_out, -1)
+            acc_att = compute_accuracy(
+                preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
+            )
+
+        return loss_att, acc_att, None, None
+
+    def init_beam_search(
+        self,
+        **kwargs,
+    ):
+        from .search import BeamSearch
+
+        from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+        # 1. Build ASR model
+        scorers = {}
+
+        scorers.update(
+            decoder=self.decoder,
+            length_bonus=LengthBonus(self.vocab_size),
+        )
+
+        weights = dict(
+            decoder=1.0,
+            ctc=0.0,
+            lm=0.0,
+            ngram=0.0,
+            length_bonus=kwargs.get("penalty", 0.0),
+        )
+        beam_search = BeamSearch(
+            beam_size=kwargs.get("beam_size", 5),
+            weights=weights,
+            scorers=scorers,
+            sos=None,
+            eos=None,
+            vocab_size=self.vocab_size,
+            token_list=None,
+            pre_beam_score_key="full",
+        )
+
+        self.beam_search = beam_search
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+
+        # init beamsearch
+        if not hasattr(self, "beam_search") or self.beam_search is None:
+            logging.info("enable beam_search")
+            self.init_beam_search(**kwargs)
+            self.nbest = kwargs.get("nbest", 1)
+
+        if frontend is None and not hasattr(self, "frontend"):
+            frontend_class = tables.frontend_classes.get("WhisperFrontend")
+            frontend = frontend_class(
+                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
+            )
+            self.frontend = frontend
+        else:
+            frontend = frontend if frontend is not None else self.frontend
+
+        meta_data = {}
+        if (
+            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
+        ):  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(
+                data_in,
+                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
+                audio_fs=kwargs.get("fs", 16000),
+                data_type=kwargs.get("data_type", "sound"),
+                tokenizer=tokenizer,
+            )
+
+            if (
+                isinstance(kwargs.get("data_type", None), (list, tuple))
+                and len(kwargs.get("data_type", [])) > 1
+            ):
+                audio_sample_list, text_token_int_list = audio_sample_list
+                text_token_int = text_token_int_list[0]
+            else:
+                text_token_int = None
+
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+            )
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
+            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])[0, :, :]
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        DecodingOptions = kwargs.get("DecodingOptions", {})
+        task = DecodingOptions.get("task", "ASR")
+        if isinstance(task, str):
+            task = [task]
+        task = "".join([f"<|{x}|>" for x in task])
+        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
+
+        language = DecodingOptions.get("language", None)
+        language = None if language == "auto" else language
+
+        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+        sos_int = tokenizer.encode(sos, allowed_special="all")
+        eos = kwargs.get("model_conf").get("eos")
+        eos_int = tokenizer.encode(eos, allowed_special="all")
+        self.beam_search.sos = sos_int
+        self.beam_search.eos = eos_int[0]
+
+        # Paramterts for rich decoding
+        self.beam_search.emo_unk = tokenizer.encode(
+            DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
+        )[0]
+        self.beam_search.emo_unk_score = 1
+        self.beam_search.emo_tokens = tokenizer.encode(
+            DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
+            allowed_special="all",
+        )
+        self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
+
+        self.beam_search.event_bg_token = tokenizer.encode(
+            DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
+            allowed_special="all",
+        )
+        self.beam_search.event_ed_token = tokenizer.encode(
+            DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
+            allowed_special="all",
+        )
+        self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
+
+        encoder_out, encoder_out_lens = self.encode(
+            speech[None, :, :].permute(0, 2, 1), speech_lengths
+        )
+
+        if text_token_int is not None:
+            i = 0
+            results = []
+            ibest_writer = None
+            if kwargs.get("output_dir") is not None:
+                if not hasattr(self, "writer"):
+                    self.writer = DatadirWriter(kwargs.get("output_dir"))
+                ibest_writer = self.writer[f"1best_recog"]
+
+            # 1. Forward decoder
+            ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
+                None, :
+            ]
+            ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to(
+                kwargs["device"]
+            )[None, :]
+            decoder_out = self.model.decoder(
+                x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
+            )
+
+            token_int = decoder_out.argmax(-1)[0, :].tolist()
+            text = tokenizer.decode(token_int)
+
+            result_i = {"key": key[i], "text": text}
+            results.append(result_i)
+
+            if ibest_writer is not None:
+                # ibest_writer["token"][key[i]] = " ".join(token)
+                ibest_writer["text"][key[i]] = text
+            return results, meta_data
+
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0],
+            maxlenratio=kwargs.get("maxlenratio", 0.0),
+            minlenratio=kwargs.get("minlenratio", 0.0),
+        )
+
+        nbest_hyps = nbest_hyps[: self.nbest]
+
+        results = []
+        b, n, d = encoder_out.size()
+        for i in range(b):
+
+            for nbest_idx, hyp in enumerate(nbest_hyps):
+                ibest_writer = None
+                if kwargs.get("output_dir") is not None:
+                    if not hasattr(self, "writer"):
+                        self.writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+
+                # remove sos/eos and get results
+                last_pos = -1
+                if isinstance(hyp.yseq, list):
+                    token_int = hyp.yseq[1:last_pos]
+                else:
+                    token_int = hyp.yseq[1:last_pos].tolist()
+
+                # # remove blank symbol id, which is assumed to be 0
+                # token_int = list(
+                #     filter(
+                #         lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
+                #     )
+                # )
+
+                # Change integer-ids to tokens
+                # token = tokenizer.ids2tokens(token_int)
+                text = tokenizer.decode(token_int)
+
+                result_i = {"key": key[i], "text": text}
+                results.append(result_i)
+
+                if ibest_writer is not None:
+                    # ibest_writer["token"][key[i]] = " ".join(token)
+                    ibest_writer["text"][key[i]] = text
+
+        return results, meta_data
diff --git a/funasr/tokenizer/sentencepiece_tokenizer.py b/funasr/tokenizer/sentencepiece_tokenizer.py
index ff4b3a2..1be1b81 100644
--- a/funasr/tokenizer/sentencepiece_tokenizer.py
+++ b/funasr/tokenizer/sentencepiece_tokenizer.py
@@ -20,6 +20,7 @@
         # "TypeError: can't pickle SwigPyObject objects",
         # when giving it as argument of "multiprocessing.Process()".
         self.sp = None
+        self._build_sentence_piece_processor()
 
     def __repr__(self):
         return f'{self.__class__.__name__}(model="{self.bpemodel}")'
@@ -38,10 +39,13 @@
         self._build_sentence_piece_processor()
         return self.sp.DecodePieces(list(tokens))
 
-    def encode(self, line: str) -> List[int]:
+    def encode(self, line: str, **kwargs) -> List[int]:
         self._build_sentence_piece_processor()
         return self.sp.EncodeAsIds(line)
 
-    def decode(self, line: List[int]):
+    def decode(self, line: List[int], **kwargs):
         self._build_sentence_piece_processor()
         return self.sp.DecodeIds(line)
+
+    def get_vocab_size(self):
+        return self.sp.GetPieceSize()
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 01e2924..50f99f0 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -382,8 +382,6 @@
                     ):
                         torch.cuda.empty_cache()
 
-                time3 = time.perf_counter()
-                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                 loss, stats, weight = retval
                 stats = {k: v for k, v in stats.items() if v is not None}
                 if self.use_ddp or self.use_fsdp:
@@ -398,34 +396,28 @@
                     # Multiply world_size because DistributedDataParallel
                     # automatically normalizes the gradient by world_size.
                     loss *= self.world_size
+                # loss *= self.world_size
                 # Scale the loss since we're not updating for every mini-batch
                 loss = loss / accum_grad
+
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                 if self.use_fp16:
                     scaler.scale(loss).backward()
                 else:
                     loss.backward()
                 time4 = time.perf_counter()
-                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
+                speed_stats["backward_and_AllReaduce_time"] = f"{time4 - time3:0.3f}"
 
                 self.train_loss_avg = (
-                    self.train_loss_avg * (self.step_in_epoch - 1) + loss.detach().cpu().item()
-                ) / self.step_in_epoch
+                    self.train_loss_avg * (batch_idx + kwargs.get("start_step", 0))
+                    + loss.detach().cpu().item()
+                ) / (batch_idx + kwargs.get("start_step", 0) + 1)
                 if "acc" in stats:
                     self.train_acc_avg = (
-                        self.train_acc_avg * (self.step_in_epoch - 1)
+                        self.train_acc_avg * (batch_idx + kwargs.get("start_step", 0))
                         + stats["acc"].detach().cpu().item()
-                    ) / self.step_in_epoch
-                if self.use_ddp or self.use_fsdp:
-                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
-                        self.device
-                    )
-                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
-                        self.device
-                    )
-                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
-                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
-                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
-                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+                    ) / (batch_idx + kwargs.get("start_step", 0) + 1)
 
             # Perform an optimizer step only after accumulating enough gradients
             if (batch_idx + 1) % accum_grad == 0:
@@ -454,8 +446,22 @@
                 scheduler.step()
                 # Clear gradients for the next accumulation stage
                 optim.zero_grad(set_to_none=True)
-                total_time = f"{time.perf_counter() - time5:0.3f}"
+
+                if self.use_ddp or self.use_fsdp:
+                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
+                        self.device
+                    )
+                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
+                        self.device
+                    )
+                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+
+                total_time = f"{(time.perf_counter() - time5)/accum_grad:0.3f}"
                 time5 = time.perf_counter()
+
                 speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
 
                 speed_stats["total_time"] = total_time
@@ -662,9 +668,9 @@
                 f"data_slice: {data_split_i}/{data_split_num}, "
                 f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, "
                 f"(loss_avg_rank: {loss:.3f}), "
-                f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
-                f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3e}), "
-                f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
+                f"(loss_avg_slice: {loss_avg_epoch:.3f}), "
+                f"(ppl_avg_slice: {math.exp(loss_avg_epoch):.3e}), "
+                f"(acc_avg_slice: {acc_avg_epoch:.3f}), "
                 f"(lr: {lr:.3e}), "
                 f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
                 f"{speed_stats}, "
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
new file mode 100644
index 0000000..7188921
--- /dev/null
+++ b/funasr/train_utils/trainer_ds.py
@@ -0,0 +1,800 @@
+import math
+import os
+import time
+import torch
+import logging
+from tqdm import tqdm
+from datetime import datetime
+import torch.distributed as dist
+from torch.cuda.amp import autocast, GradScaler
+from contextlib import nullcontext, contextmanager
+from pathlib import Path
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from funasr.train_utils.device_funcs import to_device
+from funasr.train_utils.recursive_op import recursive_average
+from funasr.train_utils.average_nbest_models import average_checkpoints
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+try:
+    import wandb
+except:
+    wandb = None
+
+
+@contextmanager
+def maybe_autocast(enabled):
+    if enabled:
+        with autocast():
+            yield
+    else:
+        yield
+
+
+class Trainer:
+    """
+    A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
+    and optionally resuming from a saved checkpoint.
+
+    Attributes:
+        max_epoch (int): Maximum number of epochs for training.
+        model (torch.nn.Module): The model to be trained.
+        optim (torch.optim.Optimizer): The optimizer to use for training.
+        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
+        dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
+        dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
+        output_dir (str): Directory where model checkpoints will be saved.
+        resume (str, optional): Path to a checkpoint to resume training from.
+    """
+
+    def __init__(
+        self,
+        rank=0,
+        local_rank=0,
+        world_size=1,
+        use_ddp: bool = False,
+        use_fsdp: bool = False,
+        use_fp16: bool = False,
+        use_deepspeed: bool = False,
+        output_dir: str = "./",
+        **kwargs,
+    ):
+        """
+        Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
+
+        Args:
+            model (torch.nn.Module): The model to be trained.
+            optim (torch.optim.Optimizer): The optimizer to use for training.
+            scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
+            dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
+            dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
+            **kwargs: Additional keyword arguments:
+                      max_epoch (int): The maximum number of epochs for training.
+                      output_dir (str): The directory where model checkpoints will be saved. Default is './'.
+                      resume (str, optional): The file path to a checkpoint to resume training from.
+        """
+        self.rank = kwargs.get("rank", 0)
+        self.local_rank = local_rank
+        self.world_size = world_size
+        self.use_ddp = use_ddp
+        self.use_fsdp = use_fsdp
+        self.use_deepspeed = use_deepspeed
+        self.device = kwargs.get("device", "cuda")
+
+        self.output_dir = output_dir
+        if not os.path.exists(self.output_dir):
+            os.makedirs(self.output_dir, exist_ok=True)
+        self.resume = kwargs.get("resume", True)
+        self.start_epoch = 0
+        self.max_epoch = kwargs.get("max_epoch", 100)
+
+        # self.kwargs = kwargs
+        self.log_interval = kwargs.get("log_interval", 50)
+        self.batch_total = 0
+        self.use_fp16 = use_fp16
+        self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
+        self.validate_interval = kwargs.get("validate_interval", 5000)
+        self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
+        self.avg_keep_nbest_models_type = kwargs.get("avg_keep_nbest_models_type", "acc")
+        self.avg_nbest_model = kwargs.get("avg_nbest_model", 10)
+        self.accum_grad = kwargs.get("accum_grad", 1)
+        self.grad_clip = kwargs.get("grad_clip", 10.0)
+        self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
+
+        self.train_acc_avg = 0.0
+        self.train_loss_avg = 0.0
+        self.val_acc_avg = 0.0
+        self.val_loss_avg = 0.0
+        self.best_acc_idx = 0
+        self.saved_ckpts = {}
+        self.step_or_epoch = -1
+        self.best_step_or_epoch = ""
+        self.val_acc_step_or_eoch = {}
+        self.val_loss_step_or_eoch = {}
+
+        self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
+        self.start_data_split_i = 0
+        self.start_step = 0
+        self.step_in_epoch = 0
+        self.use_wandb = kwargs.get("use_wandb", False)
+        if self.use_wandb:
+            wandb.login(key=kwargs.get("wandb_token"))
+            wandb.init(
+                config=kwargs,
+                project=kwargs.get("wandb_project", "my_project"),
+                entity=kwargs.get("wandb_team", "my_team"),
+                name=kwargs.get("wandb_exp_name", "my_exp"),
+                dir=output_dir,
+                job_type="training",
+                reinit=True,
+            )
+
+    def save_checkpoint(
+        self,
+        epoch,
+        step=None,
+        model=None,
+        optim=None,
+        scheduler=None,
+        scaler=None,
+        step_in_epoch=None,
+        **kwargs,
+    ):
+        """
+        Saves a checkpoint containing the model's state, the optimizer's state,
+        and the scheduler's state at the end of the given epoch. This method is
+        intended to be called at the end of each epoch to save the training progress.
+
+        Args:
+            epoch (int): The epoch number at which the checkpoint is being saved.
+        """
+
+        step_in_epoch = None if step is None else step_in_epoch
+        if self.rank == 0:
+            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
+            # self.step_or_epoch += 1
+            state = {
+                "epoch": epoch,
+                "state_dict": model.state_dict(),
+                "optimizer": optim.state_dict(),
+                "scheduler": scheduler.state_dict(),
+                "saved_ckpts": self.saved_ckpts,
+                "val_acc_step_or_eoch": self.val_acc_step_or_eoch,
+                "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
+                "best_step_or_epoch": self.best_step_or_epoch,
+                "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
+                "step": step,
+                "step_in_epoch": step_in_epoch,
+                "data_split_i": kwargs.get("data_split_i", 0),
+                "data_split_num": kwargs.get("data_split_num", 1),
+                "batch_total": self.batch_total,
+                "train_loss_avg": kwargs.get("train_loss_avg", 0),
+                "train_acc_avg": kwargs.get("train_acc_avg", 0),
+            }
+            step = step_in_epoch
+            if hasattr(model, "module"):
+                state["state_dict"] = model.module.state_dict()
+
+            if scaler:
+                state["scaler_state"] = scaler.state_dict()
+            # Create output directory if it does not exist
+            os.makedirs(self.output_dir, exist_ok=True)
+            if step is None:
+                ckpt_name = f"model.pt.ep{epoch}"
+            else:
+                ckpt_name = f"model.pt.ep{epoch}.{step}"
+            filename = os.path.join(self.output_dir, ckpt_name)
+            torch.save(state, filename)
+
+            logging.info(f"\nCheckpoint saved to {filename}\n")
+            latest = Path(os.path.join(self.output_dir, f"model.pt"))
+            torch.save(state, latest)
+            if self.best_step_or_epoch == "":
+                self.best_step_or_epoch = ckpt_name
+
+            if self.avg_keep_nbest_models_type == "acc":
+                if (
+                    self.val_acc_step_or_eoch[ckpt_name]
+                    >= self.val_acc_step_or_eoch[self.best_step_or_epoch]
+                ):
+                    self.best_step_or_epoch = ckpt_name
+                    best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
+                    torch.save(state, best_ckpt)
+                    logging.info(
+                        f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+                    )
+                else:
+                    logging.info(
+                        f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+                    )
+            elif self.avg_keep_nbest_models_type == "loss":
+                if (
+                    self.val_loss_step_or_eoch[ckpt_name]
+                    <= self.val_loss_step_or_eoch[self.best_step_or_epoch]
+                ):
+                    self.best_step_or_epoch = ckpt_name
+                    best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
+                    torch.save(state, best_ckpt)
+                    logging.info(
+                        f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+                    )
+                else:
+                    logging.info(
+                        f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+                    )
+            else:
+                print("Undo")
+            self.saved_ckpts[ckpt_name] = getattr(
+                self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
+            )[ckpt_name]
+            if self.keep_nbest_models > 0:
+                if len(self.saved_ckpts) > self.keep_nbest_models:
+                    if self.avg_keep_nbest_models_type == "acc":
+                        key = min(self.saved_ckpts, key=self.saved_ckpts.get)
+                    else:
+                        key = max(self.saved_ckpts, key=self.saved_ckpts.get)
+                    if key in self.saved_ckpts:
+                        del self.saved_ckpts[key]
+                    filename = os.path.join(self.output_dir, key)
+                    logging.info(f"Delete: {filename}")
+                    if os.path.exists(filename):
+                        os.remove(filename)
+
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+
+    def resume_checkpoint(
+        self,
+        model=None,
+        optim=None,
+        scheduler=None,
+        scaler=None,
+    ):
+        """
+        Resumes training from a checkpoint at the given file path.
+        Loads the model's state, the optimizer's state, and the scheduler's state.
+
+        Args:
+            resume_path (str): The file path to the checkpoint to resume from.
+        """
+        if self.resume:
+            ckpt = os.path.join(self.output_dir, "model.pt")
+            if os.path.isfile(ckpt):
+                checkpoint = torch.load(ckpt, map_location="cpu")
+                self.start_epoch = checkpoint["epoch"]
+                # self.model.load_state_dict(checkpoint['state_dict'])
+                src_state = checkpoint["state_dict"]
+                dst_state = model.state_dict()
+                for k in dst_state.keys():
+                    if not k.startswith("module.") and "module." + k in src_state.keys():
+                        k_ddp = "module." + k
+                    elif k.startswith("module.") and "module." + k not in src_state.keys():
+                        k_ddp = k.replace("module.", "", 1)
+                    else:
+                        k_ddp = k
+                    if k_ddp in src_state.keys():
+                        dst_state[k] = src_state[k_ddp]
+                    else:
+                        print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
+
+                model.load_state_dict(dst_state)
+                optim.load_state_dict(checkpoint["optimizer"])
+                scheduler.load_state_dict(checkpoint["scheduler"])
+                if scaler is not None and "scaler_state" in checkpoint:
+                    scaler.load_state_dict(checkpoint["scaler_state"])
+
+                self.saved_ckpts = checkpoint["saved_ckpts"]
+                self.val_acc_step_or_eoch = (
+                    checkpoint["val_acc_step_or_eoch"]
+                    if "val_acc_step_or_eoch" in checkpoint
+                    else {}
+                )
+                self.val_loss_step_or_eoch = (
+                    checkpoint["val_loss_step_or_eoch"]
+                    if "val_loss_step_or_eoch" in checkpoint
+                    else {}
+                )
+                self.best_step_or_epoch = (
+                    checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
+                )
+                self.start_data_split_i = (
+                    checkpoint["data_split_i"] if "data_split_i" in checkpoint else 0
+                )
+                self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
+                self.start_step = checkpoint["step"] if "step" in checkpoint else 0
+                self.start_step = 0 if self.start_step is None else self.start_step
+                self.step_in_epoch = (
+                    checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
+                )
+                self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
+                print(checkpoint["train_acc_avg"])
+                self.train_acc_avg = (
+                    checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
+                )
+                self.train_loss_avg = (
+                    checkpoint["train_loss_avg"] if "train_loss_avg" in checkpoint else 0
+                )
+                model.to(self.device)
+                print(f"Checkpoint loaded successfully from '{ckpt}'")
+            else:
+                print(f"No checkpoint found at '{ckpt}', does not resume status!")
+
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+
+    def train_epoch(
+        self,
+        model=None,
+        optim=None,
+        scheduler=None,
+        scaler=None,
+        dataloader_train=None,
+        dataloader_val=None,
+        epoch=None,
+        writer=None,
+        **kwargs,
+    ):
+        """
+        Defines the training process for a single epoch with gradient accumulation.
+        Args:
+            epoch (int): The current epoch number.
+        """
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+        logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
+        model.train()
+
+        # Set the number of steps for gradient accumulation
+        accum_grad = self.accum_grad
+        # Initialize the gradient accumulation
+        optim.zero_grad()
+        speed_stats = {}
+
+        iterator_stop = torch.tensor(0).to(self.device)
+
+        dataloader_train.batch_sampler.set_epoch(epoch)
+        time_beg = time.perf_counter()
+        time5 = time_beg
+        for batch_idx, batch in enumerate(dataloader_train):
+            if self.use_ddp or self.use_fsdp:
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
+            self.batch_total += 1
+            self.step_in_epoch += 1
+            time1 = time.perf_counter()
+            speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
+
+            batch = to_device(batch, self.device)
+
+            my_context = nullcontext
+            if self.use_ddp or self.use_fsdp:
+                my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context
+            with my_context():
+                time2 = time.perf_counter()
+                loss_dict = {}
+                self.forward_step(model, batch, loss_dict=loss_dict)
+
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+                self.backward_step(model, scaler, loss_dict=loss_dict)
+
+                time4 = time.perf_counter()
+                speed_stats["backward_and_AllReaduce_time"] = f"{time4 - time3:0.3f}"
+
+                # self.train_loss_avg = (
+                #     self.train_loss_avg * (batch_idx + kwargs.get("start_step", 0))
+                #     + loss.detach().cpu().item()
+                # ) / (batch_idx + kwargs.get("start_step", 0) + 1)
+                # if "acc" in stats:
+                #     self.train_acc_avg = (
+                #         self.train_acc_avg * (batch_idx + kwargs.get("start_step", 0))
+                #         + stats["acc"].detach().cpu().item()
+                #     ) / (batch_idx + kwargs.get("start_step", 0) + 1)
+
+            self.update_step(model, optim, scheduler, scaler, loss_dict)
+            # Perform an optimizer step only after accumulating enough gradients
+
+            if self.step_in_epoch % self.validate_interval == 0:
+                self.validate_epoch(
+                    model=model,
+                    dataloader_val=dataloader_val,
+                    epoch=epoch,
+                    writer=writer,
+                    step=batch_idx + 1,
+                    step_in_epoch=self.step_in_epoch,
+                )
+
+            if self.step_in_epoch % self.save_checkpoint_interval == 0:
+                self.save_checkpoint(
+                    epoch,
+                    model=model,
+                    optim=optim,
+                    scheduler=scheduler,
+                    scaler=scaler,
+                    step=batch_idx + 1,
+                    step_in_epoch=self.step_in_epoch,
+                    data_split_i=kwargs.get("data_split_i", 0),
+                    data_split_num=kwargs.get("data_split_num", 1),
+                    train_loss_avg=self.train_loss_avg,
+                    train_acc_avg=self.train_acc_avg,
+                )
+
+            time_beg = time.perf_counter()
+        else:
+            if self.use_ddp or self.use_fsdp:
+                iterator_stop.fill_(1)
+                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+            iterator_stop = torch.tensor(0).to(self.device)
+
+    def forward_step(self, model, batch, loss_dict={}):
+        with maybe_autocast(self.use_fp16):
+            retval = model(**batch)
+
+            if (
+                self.reset_gpu_cache
+                and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
+            ):
+                torch.cuda.empty_cache()
+
+        loss, stats, weight = retval
+        stats = {k: v for k, v in stats.items() if v is not None}
+        # if self.use_ddp or self.use_fsdp:
+        #     # Apply weighted averaging for loss and stats
+        #     loss = (loss * weight.type(loss.dtype)).sum()
+        #     # if distributed, this method can also apply all_reduce()
+        #     # stats, weight = recursive_average(stats, weight, distributed=True)
+        #     if self.use_ddp or self.use_fsdp:
+        #         dist.all_reduce(weight, op=dist.ReduceOp.SUM)
+        #     # Now weight is summation over all workers
+        #     loss /= weight.sum()  # shape:[1] -> shape:[]
+        #     # Multiply world_size because DistributedDataParallel
+        #     # automatically normalizes the gradient by world_size.
+        #     loss *= self.world_size
+        # loss *= self.world_size
+        # Scale the loss since we're not updating for every mini-batch
+
+        loss_dict["loss"] = loss
+        loss_dict["stats"] = stats
+        loss_dict["weight"] = weight
+
+    def backward_step(self, model, scaler, loss_dict={}):
+        loss = loss_dict["loss"]
+
+        if self.use_deepspeed:
+            scaled_loss = model.backward(loss)
+        else:
+            loss = loss / self.accum_grad
+            if self.use_fp16:
+                scaler.scale(loss).backward()
+            else:
+                loss.backward()
+
+    def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=loss_dict):
+        if (batch_idx + 1) % self.accum_grad == 0:
+            # Perform gradient clipping if it is set
+            if self.grad_clip > 0:
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    model.parameters(),
+                    max_norm=self.grad_clip,
+                    norm_type=self.grad_clip_type,
+                )
+                if not torch.isfinite(grad_norm):
+                    logging.warning(f"The grad norm is {grad_norm}. Skipping updating the model.")
+                    optim.zero_grad()  # Reset gradients
+                    return
+
+            # Execute an optimization step (update model parameters)
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+            if self.use_fp16:
+                scaler.step(optim)
+                scaler.update()
+            else:
+                optim.step()
+            scheduler.step()
+            # Clear gradients for the next accumulation stage
+            optim.zero_grad(set_to_none=True)
+
+            if self.use_ddp or self.use_fsdp:
+                train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
+                    self.device
+                )
+                train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
+                    self.device
+                )
+                dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+                dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+                self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+                self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+
+            total_time = f"{(time.perf_counter() - time5) / accum_grad:0.3f}"
+            time5 = time.perf_counter()
+
+            speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
+
+            speed_stats["total_time"] = total_time
+            lr = scheduler.get_last_lr()[0]
+            batch_num_epoch = 1
+            if hasattr(dataloader_train, "__len__"):
+                batch_num_epoch = len(dataloader_train)
+            self.log(
+                epoch,
+                batch_idx,
+                log_step=batch_idx + kwargs.get("start_step", 0),
+                step_in_epoch=self.step_in_epoch,
+                batch_num_epoch=batch_num_epoch,
+                lr=lr,
+                loss=loss.detach().cpu().item(),
+                speed_stats=speed_stats,
+                stats=stats,
+                writer=writer,
+                tag="train",
+                data_split_i=kwargs.get("data_split_i", 0),
+                data_split_num=kwargs.get("data_split_num", 1),
+            )
+
+    def validate_epoch(
+        self,
+        model=None,
+        dataloader_val=None,
+        epoch=None,
+        writer=None,
+        **kwargs,
+    ):
+        """
+        Defines the validation process for a single epoch.
+        Should be implemented with the actual model validation steps.
+
+        Args:
+            epoch (int): The current epoch number.
+        """
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+        logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
+        model.eval()
+
+        with torch.no_grad():
+
+            speed_stats = {}
+            time5 = time.perf_counter()
+            iterator_stop = torch.tensor(0).to(self.device)
+            dataloader_val.batch_sampler.set_epoch(epoch)
+            for batch_idx, batch in enumerate(dataloader_val):
+                if self.use_ddp or self.use_fsdp:
+                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+                    if iterator_stop > 0:
+                        break
+                time1 = time.perf_counter()
+                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
+                batch = to_device(batch, self.device)
+                time2 = time.perf_counter()
+                retval = model(**batch)
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+                loss, stats, weight = retval
+                stats = {k: v for k, v in stats.items() if v is not None}
+                if self.use_ddp or self.use_fsdp:
+                    # Apply weighted averaging for loss and stats
+                    loss = (loss * weight.type(loss.dtype)).sum()
+                    # if distributed, this method can also apply all_reduce()
+                    # stats, weight = recursive_average(stats, weight, distributed=True)
+                    if self.use_ddp or self.use_fsdp:
+                        dist.all_reduce(weight, op=dist.ReduceOp.SUM)
+                    # Now weight is summation over all workers
+                    loss /= weight.sum()  # shape:[1] -> shape:[]
+                    # Multiply world_size because DistributedDataParallel
+                    # automatically normalizes the gradient by world_size.
+                    loss *= self.world_size
+                # Scale the loss since we're not updating for every mini-batch
+                loss = loss
+                time4 = time.perf_counter()
+
+                self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
+                    batch_idx + 1
+                )
+                if "acc" in stats:
+                    self.val_acc_avg = (
+                        self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
+                    ) / (batch_idx + 1)
+                if self.use_ddp or self.use_fsdp:
+                    val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
+                        self.device
+                    )
+                    val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
+                        self.device
+                    )
+                    dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+                    self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+                    self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+                time5 = time.perf_counter()
+                batch_num_epoch = 1
+                if hasattr(dataloader_val, "__len__"):
+                    batch_num_epoch = len(dataloader_val)
+                self.log(
+                    epoch,
+                    batch_idx,
+                    batch_num_epoch=batch_num_epoch,
+                    lr=0.0,
+                    loss=loss.detach().cpu().item(),
+                    speed_stats=speed_stats,
+                    stats=stats,
+                    writer=writer,
+                    tag="val",
+                )
+
+            else:
+                if self.use_ddp or self.use_fsdp:
+                    iterator_stop.fill_(1)
+                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+
+        if kwargs.get("step_in_epoch", None) is None:
+            ckpt_name = f"model.pt.ep{epoch}"
+        else:
+            ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
+        self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
+        self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
+        model.train()
+
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+            iterator_stop = torch.tensor(0).to(self.device)
+
+    def log(
+        self,
+        epoch=0,
+        batch_idx=0,
+        step_in_epoch=0,
+        batch_num_epoch=-1,
+        lr=0.0,
+        loss=0.0,
+        speed_stats=None,
+        stats=None,
+        writer=None,
+        tag="train",
+        data_split_i=0,
+        data_split_num=1,
+        log_step=None,
+        **kwargs,
+    ):
+
+        if (batch_idx + 1) % self.log_interval == 0:
+            batch_idx = log_step if log_step is not None else batch_idx
+            gpu_info = (
+                "GPU, memory: usage: {:.3f} GB, "
+                "peak: {:.3f} GB, "
+                "cache: {:.3f} GB, "
+                "cache_peak: {:.3f} GB".format(
+                    torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+                    torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
+                    torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
+                    torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
+                )
+            )
+
+            loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
+            acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
+            description = (
+                f"{tag}, "
+                f"rank: {self.rank}, "
+                f"epoch: {epoch}/{self.max_epoch}, "
+                f"data_slice: {data_split_i}/{data_split_num}, "
+                f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, "
+                f"(loss_avg_rank: {loss:.3f}), "
+                f"(loss_avg_slice: {loss_avg_epoch:.3f}), "
+                f"(ppl_avg_slice: {math.exp(loss_avg_epoch):.3e}), "
+                f"(acc_avg_slice: {acc_avg_epoch:.3f}), "
+                f"(lr: {lr:.3e}), "
+                f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
+                f"{speed_stats}, "
+                f"{gpu_info}"
+            )
+            logging.info(description)
+
+            description_dict = {
+                f"rank{self.rank}_loss/{tag}": loss,
+                f"rank{self.rank}_lr/{tag}": lr,
+            }
+
+            if writer is not None:
+                writer.add_scalar(f"rank{self.rank}_loss/{tag}", loss, self.batch_total)
+                writer.add_scalar(f"rank{self.rank}_lr/{tag}", lr, self.batch_total)
+                for key, var in stats.items():
+                    writer.add_scalar(
+                        f"stats_rank{self.rank}_{key}/{tag}", var.item(), self.batch_total
+                    )
+                    description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = var.item()
+                for key, var in speed_stats.items():
+                    writer.add_scalar(
+                        f"stats_rank{self.rank}_{key}/{tag}", eval(var), self.batch_total
+                    )
+                    description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = eval(var)
+            if self.use_wandb and wandb is not None:
+                wandb.log(
+                    description_dict,
+                    setp=self.batch_total,
+                )
+
+    def close(self, writer=None):
+
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+
+        if writer is not None:
+            writer.close()
+
+        if self.use_ddp or self.use_fsdp:
+            torch.distributed.destroy_process_group()
+
+    def warp_model(self, model, **kwargs):
+
+        if self.use_deepspeed:
+            from deepspeed.runtime.zero.stage_1_and_2 import (
+                estimate_zero2_model_states_mem_needs_all_live,
+            )
+            from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
+            from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+
+            local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
+            world_size = int(os.environ.get("WORLD_SIZE", 1))
+
+            # NOTE(xcsong): look in detail how the memory estimator API works:
+            #   https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
+            if int(os.environ.get("RANK", 0)) == 0:
+                logging.info("Estimating model states memory needs (zero2)...")
+                estimate_zero2_model_states_mem_needs_all_live(
+                    model,
+                    num_gpus_per_node=local_world_size,
+                    num_nodes=world_size // local_world_size,
+                )
+                logging.info("Estimating model states memory needs (zero3)...")
+                estimate_zero3_model_states_mem_needs_all_live(
+                    model,
+                    num_gpus_per_node=local_world_size,
+                    num_nodes=world_size // local_world_size,
+                )
+            device = None  # Init device later
+            pass  # Init DeepSpeed later
+
+        elif self.use_ddp:
+            local_rank = int(os.environ.get("LOCAL_RANK", 0))
+            model = model.cuda(local_rank)
+            model = DDP(
+                model,
+                device_ids=[local_rank],
+                find_unused_parameters=kwargs.get("train_conf", {}).get(
+                    "find_unused_parameters", False
+                ),
+            )
+        # elif self.use_fsdp:
+        #     # model = FSDP(model).cuda(local_rank)
+        #
+        #     def custom_auto_wrap_policy(
+        #         module: nn.Module,
+        #         recurse: bool,
+        #         nonwrapped_numel: int,
+        #         # Additional custom arguments
+        #         min_num_params: int = int(1e8),
+        #     ) -> bool:
+        #         # 鏍规嵁鑷畾涔夐�昏緫鍐冲畾鏄惁鍖呰妯″潡
+        #         is_large = unwrapped_params >= min_num_params
+        #         requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
+        #         return is_large and requires_grad_uniform
+        #
+        #     # Configure a custom `min_num_params`
+        #     my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
+        #     torch.cuda.set_device(local_rank)
+        #     model = FSDP(
+        #         model,
+        #         auto_wrap_policy=custom_auto_wrap_policy,
+        #         mixed_precision=None,
+        #         device_id=torch.cuda.current_device(),
+        #     )
+        else:
+            model = model.to(device=kwargs.get("device", "cuda"))
+
+        return model
diff --git a/funasr/utils/misc.py b/funasr/utils/misc.py
index 9f01955..4613cb3 100644
--- a/funasr/utils/misc.py
+++ b/funasr/utils/misc.py
@@ -70,14 +70,16 @@
 
     yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
     OmegaConf.save(config=kwargs, f=yaml_file)
-    print(kwargs)
+    logging.info(f"kwargs: {kwargs}")
     logging.info("config.yaml is saved to: %s", yaml_file)
 
-    # model_path = kwargs.get("model_path")
-    # if model_path is not None:
-    #     config_json = os.path.join(model_path, "configuration.json")
-    #     if os.path.exists(config_json):
-    #         shutil.copy(config_json, os.path.join(kwargs.get("output_dir", "./"), "configuration.json"))
+    model_path = kwargs.get("model_path", None)
+    if model_path is not None:
+        config_json = os.path.join(model_path, "configuration.json")
+        if os.path.exists(config_json):
+            shutil.copy(
+                config_json, os.path.join(kwargs.get("output_dir", "./"), "configuration.json")
+            )
 
 
 def extract_filename_without_extension(file_path):

--
Gitblit v1.9.1