From 4482bbcbb912f699a4faecaafd65aa15aec64a51 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 21 三月 2024 11:49:30 +0800
Subject: [PATCH] train (#1521)

---
 funasr/datasets/large_datasets/utils/filter.py                         |   26 
 funasr/train_utils/trainer.py                                          |  501 ++++++----
 funasr/register.py                                                     |    1 
 funasr/datasets/large_datasets/utils/clipping.py                       |   40 
 examples/industrial_data_pretraining/paraformer/finetune_from_local.sh |   18 
 examples/industrial_data_pretraining/paraformer/finetune.sh            |   11 
 funasr/datasets/large_datasets/build_dataloader.py                     |   97 ++
 funasr/datasets/large_datasets/utils/padding.py                        |   74 +
 funasr/bin/train.py                                                    |  215 ++--
 funasr/datasets/large_datasets/abs_iter_factory.py                     |    9 
 funasr/datasets/large_datasets/datapipes/batch.py                      |  213 ++++
 funasr/datasets/large_datasets/datapipes/filter.py                     |   24 
 funasr/datasets/large_datasets/datapipes/map.py                        |   22 
 funasr/datasets/large_datasets/utils/__init__.py                       |    0 
 funasr/datasets/large_datasets/utils/low_frame_rate.py                 |   30 
 funasr/datasets/large_datasets/collate_fn.py                           |  196 ++++
 funasr/datasets/large_datasets/__init__.py                             |    0 
 funasr/datasets/dataloader_entry.py                                    |   38 
 funasr/datasets/large_datasets/dataset.py                              |  274 +++++
 funasr/utils/misc.py                                                   |   23 
 /dev/null                                                              |  242 -----
 funasr/datasets/large_datasets/utils/hotword_utils.py                  |   33 
 funasr/datasets/large_datasets/utils/tokenize.py                       |   95 ++
 funasr/datasets/audio_datasets/samplers.py                             |  550 ++++++-----
 funasr/datasets/large_datasets/datapipes/__init__.py                   |    0 
 25 files changed, 1,924 insertions(+), 808 deletions(-)

diff --git a/examples/industrial_data_pretraining/paraformer/finetune.sh b/examples/industrial_data_pretraining/paraformer/finetune.sh
index e1273da..9fc8bf0 100644
--- a/examples/industrial_data_pretraining/paraformer/finetune.sh
+++ b/examples/industrial_data_pretraining/paraformer/finetune.sh
@@ -36,9 +36,14 @@
 ++model_revision="v2.0.4" \
 ++train_data_set_list="${train_data}" \
 ++valid_data_set_list="${val_data}" \
-++dataset_conf.batch_size=32 \
-++dataset_conf.batch_type="example" \
+++dataset_conf.batch_size=20000 \
+++dataset_conf.batch_type="token" \
 ++dataset_conf.num_workers=4 \
-++train_conf.max_epoch=20 \
+++train_conf.max_epoch=50 \
+++train_conf.log_interval=10 \
+++train_conf.resume=false \
+++train_conf.validate_interval=15 \
+++train_conf.save_checkpoint_interval=15 \
+++train_conf.keep_nbest_models=50 \
 ++optim_conf.lr=0.0002 \
 ++output_dir="${output_dir}" &> ${log_file}
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/finetune_from_local.sh b/examples/industrial_data_pretraining/paraformer/finetune_from_local.sh
index dcbdf77..b883908 100644
--- a/examples/industrial_data_pretraining/paraformer/finetune_from_local.sh
+++ b/examples/industrial_data_pretraining/paraformer/finetune_from_local.sh
@@ -59,13 +59,17 @@
 --config-name "${config_name}" \
 ++train_data_set_list="${train_data}" \
 ++valid_data_set_list="${val_data}" \
+++dataset_conf.batch_size=20000 \
+++dataset_conf.batch_type="token" \
+++dataset_conf.num_workers=4 \
+++train_conf.max_epoch=50 \
+++train_conf.log_interval=10 \
+++train_conf.resume=false \
+++train_conf.validate_interval=15 \
+++train_conf.save_checkpoint_interval=15 \
+++train_conf.keep_nbest_models=50 \
+++optim_conf.lr=0.0002 \
+++init_param="${init_param}" \
 ++tokenizer_conf.token_list="${tokens}" \
 ++frontend_conf.cmvn_file="${cmvn_file}" \
-++dataset_conf.batch_size=32 \
-++dataset_conf.batch_type="example" \
-++dataset_conf.num_workers=4 \
-++train_conf.max_epoch=20 \
-++optim_conf.lr=0.0002 \
-++train_conf.log_interval=1 \
-++init_param="${init_param}" \
 ++output_dir="${output_dir}" &> ${log_file}
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 3c93371..3f97f9e 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -4,15 +4,23 @@
 import os
 import sys
 import torch
+import torch.nn as nn
 import hydra
 import logging
+import time
 import argparse
 from io import BytesIO
+
+from contextlib import nullcontext
 import torch.distributed as dist
 from collections.abc import Sequence
 from omegaconf import DictConfig, OmegaConf
+from torch.cuda.amp import autocast, GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.algorithms.join import Join
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from funasr.train_utils.average_nbest_models import average_checkpoints
 
 from funasr.register import tables
 from funasr.optimizers import optim_classes
@@ -23,10 +31,8 @@
 from funasr.models.lora.utils import mark_only_lora_as_trainable
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
-# from funasr.tokenizer.build_tokenizer import build_tokenizer
-# from funasr.tokenizer.token_id_converter import TokenIDConverter
-# from funasr.tokenizer.funtoken import build_tokenizer
-
+from funasr.utils.misc import prepare_model_dir
+from funasr import AutoModel
 
 @hydra.main(config_name=None, version_base=None)
 def main_hydra(kwargs: DictConfig):
@@ -43,7 +49,6 @@
 
 
 def main(**kwargs):
-    print(kwargs)
     
     # set random seed
     set_all_random_seed(kwargs.get("seed", 0))
@@ -56,65 +61,29 @@
         tables.print()
     # Check if we are using DDP or FSDP
     use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
-    use_fsdp = kwargs.get("use_fsdp", None)
+    use_fsdp = kwargs.get("use_fsdp", False)
+    # use_ddp = False if use_fsdp else use_fsdp
     if use_ddp or use_fsdp:
         dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
         torch.cuda.set_device(local_rank)
+
+    logging.info("Build model, frontend, tokenizer")
+    device = kwargs.get("device", "cuda")
+    kwargs["device"] = "cpu"
+    model = AutoModel(**kwargs)
+    
     
     # save config.yaml
     if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
-        os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
-        yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
-        OmegaConf.save(config=kwargs, f=yaml_file)
-        logging.info("config.yaml is saved to: %s", yaml_file)
-
-    tokenizer = kwargs.get("tokenizer", None)
-    if tokenizer is not None:
-        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
-        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
-        kwargs["tokenizer"] = tokenizer
+        prepare_model_dir(**kwargs)
     
-    # build frontend if frontend is none None
-    frontend = kwargs.get("frontend", None)
-    if frontend is not None:
-        frontend_class = tables.frontend_classes.get(frontend)
-        frontend = frontend_class(**kwargs["frontend_conf"])
-        kwargs["frontend"] = frontend
-        kwargs["input_size"] = frontend.output_size()
-
-
-    # build model
-    model_class = tables.model_classes.get(kwargs["model"])
-    vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
-    vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
-    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
-
-
-
-    # init_param
-    init_param = kwargs.get("init_param", None)
-    if init_param is not None:
-        if not isinstance(init_param, (list, tuple)):
-            init_param = (init_param,)
-        logging.info("init_param is not None: %s", init_param)
-        for p in init_param:
-            if os.path.exists(p):
-                logging.info(f"Loading pretrained params from {p}")
-                load_pretrained_model(
-                    model=model,
-                    path=p,
-                    ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
-                    oss_bucket=kwargs.get("oss_bucket", None),
-                    scope_map=kwargs.get("scope_map", []),
-                    excludes=kwargs.get("excludes", None),
-                )
-            else:
-                logging.info(f"Checkpoint does not exist, init randomly: {p}")
-    elif kwargs.get("init", None):
-        initialize(model, kwargs.get("init", "kaiming_normal"))
-    else:
-        print("No initialize method")
-
+    # parse kwargs
+    kwargs = model.kwargs
+    kwargs["device"] = device
+    tokenizer = kwargs["tokenizer"]
+    frontend = kwargs["frontend"]
+    model = model.model
+    del kwargs["model"]
 
     # freeze_param
     freeze_param = kwargs.get("freeze_param", None)
@@ -135,18 +104,42 @@
         model = DDP(model, device_ids=[local_rank],
                     find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
     elif use_fsdp:
-        model = FSDP(model).cuda(local_rank)
+        # model = FSDP(model).cuda(local_rank)
+
+        def custom_auto_wrap_policy(
+            module: nn.Module,
+            recurse: bool,
+            nonwrapped_numel: int,
+            # Additional custom arguments
+            min_num_params: int = int(1e8),
+        ) -> bool:
+            # 鏍规嵁鑷畾涔夐�昏緫鍐冲畾鏄惁鍖呰妯″潡
+            is_large = unwrapped_params >= min_num_params
+            requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
+            return is_large and requires_grad_uniform
+
+        # Configure a custom `min_num_params`
+        my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
+        torch.cuda.set_device(local_rank)
+        model = FSDP(model,
+                     auto_wrap_policy=custom_auto_wrap_policy,
+                     mixed_precision=None,
+                     device_id=torch.cuda.current_device())
     else:
         model = model.to(device=kwargs.get("device", "cuda"))
-        
+
+    logging.info(f"{model}")
+    kwargs["device"] = next(model.parameters()).device
         
     # optim
+    logging.info("Build optim")
     optim = kwargs.get("optim", "adam")
     assert optim in optim_classes
     optim_class = optim_classes.get(optim)
     optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
     
     # scheduler
+    logging.info("Build scheduler")
     scheduler = kwargs.get("scheduler", "warmuplr")
     assert scheduler in scheduler_classes
     scheduler_class = scheduler_classes.get(scheduler)
@@ -154,45 +147,75 @@
 
 
     # dataset
-    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
-    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
-    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
+    logging.info("Build dataloader")
+    dataloader_class = tables.dataloader_classes.get( kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
+    dataloader_tr, dataloader_val = dataloader_class(**kwargs)
 
-    # dataloader
-    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
-    batch_sampler_val = None
-    if batch_sampler is not None:
-        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
-        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
-        batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
-    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
-                                                collate_fn=dataset_tr.collator,
-                                                batch_sampler=batch_sampler,
-                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
-                                                pin_memory=True)
-    
-    dataloader_val = torch.utils.data.DataLoader(dataset_val,
-                                                collate_fn=dataset_val.collator,
-                                                batch_sampler=batch_sampler_val,
-                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
-                                                pin_memory=True)
-    trainer = Trainer(
-        model=model,
-        optim=optim,
-        scheduler=scheduler,
-        dataloader_train=dataloader_tr,
-        dataloader_val=dataloader_val,
-        local_rank=local_rank,
-        use_ddp=use_ddp,
-        use_fsdp=use_fsdp,
-        output_dir=kwargs.get("output_dir", "./exp"),
-        resume=kwargs.get("resume", True),
-        **kwargs.get("train_conf"),
-    )
-    trainer.run()
-    
+    trainer = Trainer(local_rank=local_rank,
+                      use_ddp=use_ddp,
+                      use_fsdp=use_fsdp,
+                      device=kwargs["device"],
+                      output_dir=kwargs.get("output_dir", "./exp"),
+                      **kwargs.get("train_conf"),
+                      )
+
+    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
+    scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
+
+    trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+
+    tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
+    os.makedirs(tensorboard_dir, exist_ok=True)
+    try:
+        from tensorboardX import SummaryWriter
+        writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
+    except:
+        writer = None
+
     if use_ddp or use_fsdp:
-        torch.distributed.destroy_process_group()
+        context = Join([model])
+    else:
+        context = nullcontext()
+
+    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
+        time1 = time.perf_counter()
+        with context:
+            
+            trainer.train_epoch(
+                                model=model,
+                                optim=optim,
+                                scheduler=scheduler,
+                                scaler=scaler,
+                                dataloader_train=dataloader_tr,
+                                dataloader_val=dataloader_val,
+                                epoch=epoch,
+                                writer=writer
+                                )
+        scheduler.step()
+        trainer.validate_epoch(
+            model=model,
+            dataloader_val=dataloader_val,
+            epoch=epoch,
+            writer=writer
+        )
+
+        
+        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+
+        time2 = time.perf_counter()
+        time_escaped = (time2 - time1) / 3600.0
+        logging.info(
+            f"rank: {local_rank}, "
+            f"time_escaped_epoch: {time_escaped:.3f} hours, "
+            f"estimated to finish {trainer.max_epoch} "
+            f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n")
+
+
+    if trainer.rank == 0:
+        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)
+
+    trainer.close()
+
 
     
 
diff --git a/funasr/bin/train_llm.py b/funasr/bin/train_llm.py
deleted file mode 100644
index 89f5db3..0000000
--- a/funasr/bin/train_llm.py
+++ /dev/null
@@ -1,241 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-
-import os
-import sys
-import torch
-import torch.nn as nn
-import hydra
-import logging
-import time
-import argparse
-from io import BytesIO
-
-from contextlib import nullcontext
-import torch.distributed as dist
-from collections.abc import Sequence
-from omegaconf import DictConfig, OmegaConf
-from torch.cuda.amp import autocast, GradScaler
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.algorithms.join import Join
-from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-from funasr.train_utils.average_nbest_models import average_checkpoints
-
-from funasr.register import tables
-from funasr.optimizers import optim_classes
-from funasr.train_utils.trainer_llm import Trainer
-from funasr.schedulers import scheduler_classes
-from funasr.train_utils.initialize import initialize
-from funasr.download.download_from_hub import download_model
-from funasr.models.lora.utils import mark_only_lora_as_trainable
-from funasr.train_utils.set_all_random_seed import set_all_random_seed
-from funasr.train_utils.load_pretrained_model import load_pretrained_model
-# from funasr.tokenizer.build_tokenizer import build_tokenizer
-# from funasr.tokenizer.token_id_converter import TokenIDConverter
-# from funasr.tokenizer.funtoken import build_tokenizer
-from funasr import AutoModel
-
-@hydra.main(config_name=None, version_base=None)
-def main_hydra(kwargs: DictConfig):
-    if kwargs.get("debug", False):
-        import pdb; pdb.set_trace()
-
-    assert "model" in kwargs
-    if "model_conf" not in kwargs:
-        logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
-        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
-    
-
-    main(**kwargs)
-
-
-def main(**kwargs):
-    
-    # set random seed
-    set_all_random_seed(kwargs.get("seed", 0))
-    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
-    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
-    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
-    
-    local_rank = int(os.environ.get('LOCAL_RANK', 0))
-    if local_rank == 0:
-        tables.print()
-    # Check if we are using DDP or FSDP
-    use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
-    use_fsdp = kwargs.get("use_fsdp", False)
-    # use_ddp = False if use_fsdp else use_fsdp
-    if use_ddp or use_fsdp:
-        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
-        torch.cuda.set_device(local_rank)
-
-    logging.info("Build model, frontend, tokenizer")
-    device = kwargs.get("device", "cuda")
-    kwargs["device"] = "cpu"
-    model = AutoModel(**kwargs)
-    
-    
-    # save config.yaml
-    if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
-        os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
-        yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
-        OmegaConf.save(config=kwargs, f=yaml_file)
-        print(kwargs)
-        logging.info("config.yaml is saved to: %s", yaml_file)
-    
-    # parse kwargs
-    kwargs = model.kwargs
-    kwargs["device"] = device
-    tokenizer = kwargs["tokenizer"]
-    frontend = kwargs["frontend"]
-    model = model.model
-    del kwargs["model"]
-
-    # freeze_param
-    freeze_param = kwargs.get("freeze_param", None)
-    if freeze_param is not None:
-        freeze_param = eval(freeze_param)
-        if isinstance(freeze_param, Sequence):
-            freeze_param = (freeze_param,)
-        logging.info("freeze_param is not None: %s", freeze_param)
-        for t in freeze_param:
-            for k, p in model.named_parameters():
-                if k.startswith(t + ".") or k == t:
-                    logging.info(f"Setting {k}.requires_grad = False")
-                    p.requires_grad = False
-    
-
-    if use_ddp:
-        model = model.cuda(local_rank)
-        model = DDP(model, device_ids=[local_rank],
-                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
-    elif use_fsdp:
-        # model = FSDP(model).cuda(local_rank)
-
-        def custom_auto_wrap_policy(
-            module: nn.Module,
-            recurse: bool,
-            nonwrapped_numel: int,
-            # Additional custom arguments
-            min_num_params: int = int(1e8),
-        ) -> bool:
-            # 鏍规嵁鑷畾涔夐�昏緫鍐冲畾鏄惁鍖呰妯″潡
-            is_large = unwrapped_params >= min_num_params
-            requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
-            return is_large and requires_grad_uniform
-
-        # Configure a custom `min_num_params`
-        my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
-        torch.cuda.set_device(local_rank)
-        model = FSDP(model,
-                     auto_wrap_policy=custom_auto_wrap_policy,
-                     mixed_precision=None,
-                     device_id=torch.cuda.current_device())
-    else:
-        model = model.to(device=kwargs.get("device", "cuda"))
-
-    logging.info(f"{model}")
-    kwargs["device"] = next(model.parameters()).device
-        
-    # optim
-    logging.info("Build optim")
-    optim = kwargs.get("optim", "adam")
-    assert optim in optim_classes
-    optim_class = optim_classes.get(optim)
-    optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
-    
-    # scheduler
-    logging.info("Build scheduler")
-    scheduler = kwargs.get("scheduler", "warmuplr")
-    assert scheduler in scheduler_classes
-    scheduler_class = scheduler_classes.get(scheduler)
-    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
-
-
-    # dataset
-    logging.info("Build dataloader")
-    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
-    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
-    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
-
-    # dataloader
-    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
-    batch_sampler_val = None
-    if batch_sampler is not None:
-        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
-        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
-        batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
-        
-    dataloader_tr = torch.utils.data.DataLoader(dataset_tr, collate_fn=dataset_tr.collator, **batch_sampler)
-    dataloader_val = torch.utils.data.DataLoader(dataset_val, collate_fn=dataset_val.collator, **batch_sampler_val)
-
-    trainer = Trainer(local_rank=local_rank,
-                      use_ddp=use_ddp,
-                      use_fsdp=use_fsdp,
-                      device=kwargs["device"],
-                      output_dir=kwargs.get("output_dir", "./exp"),
-                      **kwargs.get("train_conf"),
-                      )
-
-    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
-    scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
-
-    trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
-
-    tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
-    os.makedirs(tensorboard_dir, exist_ok=True)
-    try:
-        from tensorboardX import SummaryWriter
-        writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
-    except:
-        writer = None
-
-    if use_ddp or use_fsdp:
-        context = Join([model])
-    else:
-        context = nullcontext()
-
-    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
-        time1 = time.perf_counter()
-        with context:
-            
-            trainer.train_epoch(
-                                model=model,
-                                optim=optim,
-                                scheduler=scheduler,
-                                scaler=scaler,
-                                dataloader_train=dataloader_tr,
-                                dataloader_val=dataloader_val,
-                                epoch=epoch,
-                                writer=writer
-                                )
-        scheduler.step()
-        trainer.validate_epoch(
-            model=model,
-            dataloader_val=dataloader_val,
-            epoch=epoch,
-            writer=writer
-        )
-
-        
-        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
-
-        time2 = time.perf_counter()
-        time_escaped = (time2 - time1) / 3600.0
-        logging.info(
-            f"rank: {local_rank}, "
-            f"time_escaped_epoch: {time_escaped:.3f} hours, "
-            f"estimated to finish {trainer.max_epoch} "
-            f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n")
-
-
-    if trainer.rank == 0:
-        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)
-
-    trainer.close()
-
-
-    
-
-if __name__ == "__main__":
-    main_hydra()
\ No newline at end of file
diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index 914e776..a0ff4b6 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -1,277 +1,327 @@
 import torch
 import numpy as np
 import logging
+import math
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler
+from torch.utils.data import BatchSampler, Sampler
 import torch.distributed as dist
 
 from funasr.register import tables
 
 
-@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
-class BatchSampler(torch.utils.data.BatchSampler):
-    
-    def __init__(self, dataset,
-                 batch_type: str = "example",
-                 batch_size: int = 100,
-                 buffer_size: int = 30,
-                 drop_last: bool = False,
-                 shuffle: bool = True,
-                 is_training: bool = True,
-                 **kwargs):
-        
-        self.drop_last = drop_last
-        self.pre_idx = -1
-        self.dataset = dataset
-        self.total_samples = len(dataset)
-        self.batch_type = batch_type
-        self.batch_size = int(batch_size)
-        self.buffer_size = buffer_size
-        self.max_token_length = kwargs.get("max_token_length", 5000)
-        self.shuffle_idx = np.arange(self.total_samples)
-        self.shuffle = shuffle and is_training
-        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-        
-    
-    def __len__(self):
-        return (self.total_samples-1) // self.batch_size + 1
-    
-    def set_epoch(self, epoch):
-        np.random.seed(epoch)
-    
-    def __iter__(self):
-        
-        if self.shuffle:
-            np.random.shuffle(self.shuffle_idx)
-        
-        batch = []
-        max_token = 0
-        num_sample = 0
-        
-        iter_num = (self.total_samples - 1) // self.buffer_size + 1
-        # print("iter_num: ", iter_num)
-        for iter in range(self.pre_idx + 1, iter_num):
-            datalen_with_index = []
-            for i in range(self.buffer_size):
-                idx = iter * self.buffer_size + i
-                if idx >= self.total_samples:
-                    continue
-                
-                idx_map = self.shuffle_idx[idx]
-                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
-                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
-                sample_len_cur = source_len + target_len
-                
-                
-                datalen_with_index.append([idx, sample_len_cur])
-            
-            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
-            for item in datalen_with_index_sort:
-                idx, sample_len_cur_raw = item
-                if sample_len_cur_raw > self.max_token_length:
-                    continue
-                
-                max_token_cur = max(max_token, sample_len_cur_raw)
-                max_token_padding = 1 + num_sample
-                if self.batch_type != 'example':
-                    max_token_padding *= max_token_cur
-                if max_token_padding <= self.batch_size:
-                    batch.append(idx)
-                    max_token = max_token_cur
-                    num_sample += 1
-                else:
-                    yield batch
-                    batch = [idx]
-                    max_token = sample_len_cur_raw
-                    num_sample = 1
-
-
 @tables.register("batch_sampler_classes", "BatchSampler")
+@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler")
+@tables.register("batch_sampler_classes", "CustomDistributedDynamicBatchSampler")
+@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
 @tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
-class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):
-    
-    def __init__(self, dataset,
-                 batch_type: str = "example",
-                 batch_size: int = 100,
-                 buffer_size: int = 30,
-                 drop_last: bool = True,
-                 shuffle: bool = True,
-                 is_training: bool = True,
-                 **kwargs):
-        
-        self.drop_last = drop_last
-        self.pre_idx = -1
-        self.dataset = dataset
-        self.total_samples = len(dataset)
-        self.batch_type = batch_type
-        self.batch_size = int(batch_size)
-        self.buffer_size = buffer_size
-        self.max_token_length = kwargs.get("max_token_length", 1500)
-        self.shuffle_idx = np.arange(self.total_samples)
-        self.shuffle = shuffle and is_training
-        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-        
-        try:
-            rank = dist.get_rank()
-            world_size = dist.get_world_size()
-        except:
-            rank = 0
-            world_size = 1
-        self.rank = rank
-        self.world_size = world_size
-        
-    def __len__(self):
-        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
-    
-    def set_epoch(self, epoch):
-        np.random.seed(epoch)
-    
-    def __iter__(self):
-    
-        batch_size_total = self.batch_size * self.world_size
-        
-        if self.shuffle:
-            np.random.shuffle(self.shuffle_idx)
-        
-        batch = []
-        max_token = 0
-        num_sample = 0
-        
-        iter_num = (self.total_samples - 1) // self.buffer_size + 1
-        # print("iter_num: ", iter_num)
-        for iter in range(self.pre_idx + 1, iter_num):
-            # if iter == iter_num -1 and self.drop_last:
-            #     continue
-            datalen_with_index = []
-            for i in range(self.buffer_size):
-                idx = iter * self.buffer_size + i
-                if idx >= self.total_samples:
-                    continue
-                
-                idx_map = self.shuffle_idx[idx]
-                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-                
-                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
-                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
-                sample_len_cur = source_len + target_len
-                
-                datalen_with_index.append([idx, sample_len_cur])
-            
-            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
-            for item in datalen_with_index_sort:
-                idx, sample_len_cur_raw = item
-                if sample_len_cur_raw > self.max_token_length:
-                    continue
-
-                max_token_cur = max(max_token, sample_len_cur_raw)
-                max_token_padding = 1 + num_sample
-                # if self.batch_type != 'example':
-                #     max_token_padding *= max_token_cur
-                if max_token_padding <= batch_size_total:
-                    batch.append(idx)
-                    max_token = max_token_cur
-                    num_sample += 1
-                else:
-                    batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size]
-                    yield batch_rank
-                    batch = [idx]
-                    max_token = sample_len_cur_raw
-                    num_sample = 1
-
-
 @tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
-class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):
-    
-    def __init__(self, dataset,
-                 batch_type: str = "example",
-                 batch_size: int = 100,
-                 buffer_size: int = 30,
-                 drop_last: bool = True,
-                 shuffle: bool = True,
-                 is_training: bool = True,
-                 **kwargs):
+def CustomDistributedBatchSampler_fn(dataset, **kwargs):
+    dataloader_args = {}
+    batch_type = kwargs.get("batch_type", "example")
+    if batch_type == "example":
+        batch_sampler = CustomDistributedBatchSampler(dataset, **kwargs)
         
-        self.drop_last = drop_last
-        self.pre_idx = -1
+    else:
+        batch_sampler = CustomDistributedDynamicBatchSampler(dataset, **kwargs)
+        
+    dataloader_args["batch_sampler"] = batch_sampler
+    dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
+    dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
+    
+    return dataloader_args
+
+class CustomDistributedBatchSampler(Sampler):
+    def __init__(self, dataset,
+                 batch_size,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=True,
+                 drop_last=False,
+                 is_training: bool = True,
+                 **kwargs,
+                 ):
+
+        try:
+            rank = dist.get_rank()
+            num_replicas = dist.get_world_size()
+        except:
+            rank = 0
+            num_replicas = 1
+        self.rank = rank
+        self.num_replicas = num_replicas
         self.dataset = dataset
-        self.total_samples = len(dataset)
-        self.batch_type = batch_type
-        self.batch_size = int(batch_size)
-        self.buffer_size = buffer_size
-        self.max_token_length = kwargs.get("max_token_length", 1500)
-        self.shuffle_idx = np.arange(self.total_samples)
+        self.batch_size = batch_size
+        self.is_training = is_training
         self.shuffle = shuffle and is_training
+        self.drop_last = drop_last
+        # self.total_size = len(dataset)
+        if self.drop_last:
+            self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (batch_size * num_replicas)
+        else:
+            self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (batch_size * num_replicas)
+        self.num_samples = int(self.total_size // self.num_replicas)
+        self.epoch = 0
+        self.max_token_length = kwargs.get("max_token_length", None)
         self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+
+    def __iter__(self):
+        # Generate a list of indices
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
+
+        # Add extra samples to make it evenly divisible
+        padding_size = self.total_size - len(indices)
+        if padding_size <= len(indices):
+            indices += indices[:padding_size]
+        else:
+            indices += (indices * (padding_size // len(indices)) + indices[:padding_size % len(indices)])
+
+        assert len(indices) == self.total_size
+
+        # Subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        # Filter out indices with length greater than the max length, if provided
+        if self.max_token_length is not None:
+            filtered_indices = []
+            for idx in indices:
+                source_len = self.dataset.get_source_len(idx) / self.length_scale_source
+                if source_len <= self.max_token_length:
+                    filtered_indices.append(idx)
+            indices = filtered_indices
+
+        # Now that we have only the indices for this replica, chunk them into batches
+        batches = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)]
+
+        # Drop the last batch if it's not full and drop_last is True
+        if self.drop_last and len(batches[-1]) != self.batch_size:
+            batches = batches[:-1]
+
+        return iter(batches)
+
+    def __len__(self):
+
+        return self.num_samples // self.batch_size
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+class CustomDistributedBufferBatchSampler(Sampler):
+    def __init__(self, dataset,
+                 batch_size,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=True,
+                 drop_last=False,
+                 is_training: bool = True,
+                 sort_size: int = 1024,
+                 **kwargs,
+                 ):
         
         try:
             rank = dist.get_rank()
-            world_size = dist.get_world_size()
+            num_replicas = dist.get_world_size()
         except:
             rank = 0
-            world_size = 1
+            num_replicas = 1
         self.rank = rank
-        self.world_size = world_size
-    
-    def __len__(self):
-        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
-    
-    def set_epoch(self, epoch):
-        np.random.seed(epoch)
+        self.num_replicas = num_replicas
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.is_training = is_training
+        self.shuffle = shuffle and is_training
+        self.drop_last = drop_last
+        # self.total_size = len(dataset)
+        if self.drop_last:
+            self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (batch_size * num_replicas)
+        else:
+            self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (batch_size * num_replicas)
+        self.num_samples = int(self.total_size // self.num_replicas)
+        self.epoch = 0
+        self.max_token_length = kwargs.get("max_token_length", None)
+        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+        self.sort_size = sort_size
     
     def __iter__(self):
-        
-        batch_size_total = self.batch_size * self.world_size
+        # Generate a list of indices
         if self.shuffle:
-            np.random.shuffle(self.shuffle_idx)
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
         
-        batch_list_all_rank = []
-        batch_list_cur = []
-        max_token = 0
-        num_sample = 0
+        # Add extra samples to make it evenly divisible
+        padding_size = self.total_size - len(indices)
+        if padding_size <= len(indices):
+            indices += indices[:padding_size]
+        else:
+            indices += (indices * (padding_size // len(indices)) + indices[:padding_size % len(indices)])
         
-        iter_num = (self.total_samples - 1) // self.buffer_size + 1
-        # print("iter_num: ", iter_num)
-        for iter in range(self.pre_idx + 1, iter_num):
-            # if iter == iter_num - 1 and self.drop_last:
-            #     continue
-            datalen_with_index = []
-            for i in range(self.buffer_size):
-                idx = iter * self.buffer_size + i
-                if idx >= self.total_samples:
-                    continue
-                
-                idx_map = self.shuffle_idx[idx]
-                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-                
-                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
-                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
-                sample_len_cur = source_len + target_len
-                
-                datalen_with_index.append([idx, sample_len_cur])
+        assert len(indices) == self.total_size
+        
+        # Subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+        
+        # Filter out indices with length greater than the max length, if provided
+        if self.max_token_length is not None:
+            filtered_indices = []
+            for idx in indices:
+                source_len = self.dataset.get_source_len(idx) / self.length_scale_source
+                if source_len <= self.max_token_length:
+                    filtered_indices.append(idx)
+            indices = filtered_indices
+
+        # Buffer sorting logic
+        sorted_batches = []
+        buffer = []
+
+        for idx in indices:
+            buffer.append(idx)
+            if len(buffer) >= self.sort_size:
+                # Sort the buffer based on some criteria, e.g., dataset sample length
+                buffer.sort(key=lambda x: self.dataset.get_source_len(x))
+                sorted_batches.extend(self._create_batches_from_buffer(buffer))
+                buffer = []
+
+        # Handle the remaining items in the buffer
+        if buffer:
+            buffer.sort(key=lambda x: self.dataset.get_source_len(x))
+            sorted_batches.extend(self._create_batches_from_buffer(buffer))
+
+        return iter(sorted_batches)
+
+    def _create_batches_from_buffer(self, buffer):
+        # Function to convert the sorted buffer into batches
+        batched_buffer = [buffer[i:i + self.batch_size] for i in range(0, len(buffer), self.batch_size)]
+        if self.drop_last and len(batched_buffer[-1]) != self.batch_size:
+            batched_buffer = batched_buffer[:-1]
+        return batched_buffer
+    
+    def __len__(self):
+        
+        return self.num_samples // self.batch_size
+    
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+class CustomDistributedDynamicBatchSampler(Sampler):
+    def __init__(self, dataset,
+                 batch_size,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=True,
+                 drop_last=False,
+                 is_training: bool = True,
+                 **kwargs,
+                 ):
+        
+        try:
+            rank = dist.get_rank()
+            num_replicas = dist.get_world_size()
+        except:
+            rank = 0
+            num_replicas = 1
+        self.rank = rank
+        self.num_replicas = num_replicas
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.is_training = is_training
+        self.shuffle = shuffle and is_training
+        self.drop_last = drop_last
+        
+        self.total_size = len(self.dataset)
+        # self.num_samples = int(math.ceil(self.total_size / self.num_replicas))
+        self.epoch = 0
+    
+    def __iter__(self):
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
+        
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        
+        batches = []
+        batch = []
+        max_len_in_batch = 0
+        current_batch_length = 0
+        
+        for idx in indices:
+            sample_length = self.dataset.get_source_len(idx)
+            potential_batch_length = (max_len_in_batch if sample_length < max_len_in_batch else sample_length) * (
+                    len(batch) + 1)
             
-            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
-            for ii, item in enumerate(datalen_with_index_sort):
-                is_last_batch = iter == iter_num - 1 and ii == len(datalen_with_index_sort)
-                idx, sample_len_cur_raw = item
-                if sample_len_cur_raw > self.max_token_length:
-                    continue
-                
-                max_token_cur = max(max_token, sample_len_cur_raw)
-                max_token_padding = 1 + num_sample
-                
-                if self.batch_type != 'example':
-                    max_token_padding *= max_token_cur
-                if len(batch_list_all_rank) < self.world_size:
-                    
-                    if max_token_padding <= self.batch_size:
-                        batch_list_cur.append(idx)
-                        max_token = max_token_cur
-                        num_sample += 1
-                    else:
-                        batch_list_all_rank.append(batch_list_cur)
-                        batch_list_cur = []
-                else:
-                    batch_rank = batch_list_all_rank[self.rank]
-                    yield batch_rank
-                    batch_list_all_rank = [idx]
-                    max_token = sample_len_cur_raw
-                    num_sample = 1
+            if potential_batch_length <= self.batch_size:
+                batch.append(idx)
+                if sample_length > max_len_in_batch:
+                    max_len_in_batch = sample_length
+                    current_batch_length = max_len_in_batch * len(batch)
+            else:
+                batches.append(batch)
+                batch = [idx]
+                max_len_in_batch = sample_length
+                current_batch_length = max_len_in_batch
+        
+        # Add the last batch if it's not empty and we're not dropping it
+        if batch and (not self.drop_last or len(batch) * max_len_in_batch == self.batch_size):
+            batches.append(batch)
+        
+        return iter(batches)
+    
+    def __len__(self):
+        
+        return 1
+    
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+
+class DistributedSamplerWarp(BatchSampler):
+    def __init__(self, dataset, batch_size, num_replicas=None, rank=None, shuffle=True, drop_last=False):
+        if num_replicas is None:
+            if not torch.distributed.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = torch.distributed.get_world_size()
+        if rank is None:
+            if not torch.distributed.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = torch.distributed.get_rank()
+        
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.shuffle = shuffle
+        self.drop_last = drop_last
+        
+        # Create an instance of the DistributedSampler
+        self.sampler = DistributedSampler(
+            self.dataset,
+            num_replicas=self.num_replicas,
+            rank=self.rank,
+            shuffle=self.shuffle
+        )
+        
+        # Call BatchSampler's constructor
+        super().__init__(self.sampler, batch_size, drop_last)
+    
+    def __iter__(self):
+        # If we shuffle, we need to call the set_epoch method
+        if self.shuffle:
+            self.sampler.set_epoch(self.epoch)
+        
+        # Generate batch indices using the parent class
+        return super().__iter__()
+    
+    def set_epoch(self, epoch):
+        self.epoch = epoch
diff --git a/funasr/datasets/dataloader_entry.py b/funasr/datasets/dataloader_entry.py
new file mode 100644
index 0000000..21e3834
--- /dev/null
+++ b/funasr/datasets/dataloader_entry.py
@@ -0,0 +1,38 @@
+
+import logging
+import torch
+
+from funasr.register import tables
+
+@tables.register("dataloader_classes", "DataloaderMapStyle")
+def DataloaderMapStyle(frontend=None, tokenizer=None, **kwargs):
+	# dataset
+	logging.info("Build dataloader")
+	dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
+	dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
+	dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
+	
+	# dataloader
+	batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
+	batch_sampler_val = None
+	if batch_sampler is not None:
+		batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
+		batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+		batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
+	
+	dataloader_tr = torch.utils.data.DataLoader(dataset_tr, collate_fn=dataset_tr.collator, **batch_sampler)
+	dataloader_val = torch.utils.data.DataLoader(dataset_val, collate_fn=dataset_val.collator, **batch_sampler_val)
+	
+	return dataloader_tr, dataloader_val
+
+
+@tables.register("dataloader_classes", "DataloaderIterable")
+def DataloaderIterable(frontend=None, tokenizer=None, **kwargs):
+	logging.info("Build dataloader")
+	dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "LargeDataset"))
+	dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer,
+	                           is_training=True, **kwargs.get("dataset_conf"))
+	dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer,
+	                            is_training=False, **kwargs.get("dataset_conf"))
+	
+	return dataset_tr, dataset_val
\ No newline at end of file
diff --git a/funasr/datasets/large_datasets/__init__.py b/funasr/datasets/large_datasets/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/large_datasets/__init__.py
diff --git a/funasr/datasets/large_datasets/abs_iter_factory.py b/funasr/datasets/large_datasets/abs_iter_factory.py
new file mode 100644
index 0000000..36e4dd2
--- /dev/null
+++ b/funasr/datasets/large_datasets/abs_iter_factory.py
@@ -0,0 +1,9 @@
+from abc import ABC
+from abc import abstractmethod
+from typing import Iterator
+
+
+class AbsIterFactory(ABC):
+    @abstractmethod
+    def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator:
+        raise NotImplementedError
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
new file mode 100644
index 0000000..8a255f9
--- /dev/null
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -0,0 +1,97 @@
+import logging
+from pathlib import Path
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import sentencepiece as spm
+from torch.utils.data import DataLoader
+
+from funasr.datasets.large_datasets.dataset import Dataset
+from funasr.datasets.large_datasets.abs_iter_factory import AbsIterFactory
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
+
+from funasr.register import tables
+
+def read_symbol_table(symbol_table_file):
+    if isinstance(symbol_table_file, str):
+        symbol_table = {}
+        with open(symbol_table_file, "r", encoding="utf8") as fin:
+            for i, line in enumerate(fin):
+                char = line.strip()
+                symbol_table[char] = i
+    else:
+        assert isinstance(symbol_table_file, list)
+        symbol_table = {}
+        for i, char in enumerate(symbol_table_file):
+            symbol_table[char] = i
+    return symbol_table
+
+
+def load_seg_dict(seg_dict_file):
+    seg_dict = {}
+    assert isinstance(seg_dict_file, str)
+    with open(seg_dict_file, "r", encoding="utf8") as f:
+        lines = f.readlines()
+        for line in lines:
+            s = line.strip().split()
+            key = s[0]
+            value = s[1:]
+            seg_dict[key] = " ".join(value)
+    return seg_dict
+
+
+class SentencepiecesTokenizer(AbsTokenizer):
+    def __init__(self, model: Union[Path, str]):
+        self.model = str(model)
+        self.sp = None
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(model="{self.model}")'
+
+    def _build_sentence_piece_processor(self):
+        if self.sp is None:
+            self.sp = spm.SentencePieceProcessor()
+            self.sp.load(self.model)
+
+    def text2tokens(self, line: str) -> List[str]:
+        self._build_sentence_piece_processor()
+        return self.sp.EncodeAsPieces(line)
+
+    def tokens2text(self, tokens: Iterable[str]) -> str:
+        self._build_sentence_piece_processor()
+        return self.sp.DecodePieces(list(tokens))
+
+@tables.register("dataset_classes", "LargeDataset")
+class LargeDataLoader(AbsIterFactory):
+    def __init__(self, args, mode="train"):
+        symbol_table, seg_dict, punc_dict, bpe_tokenizer = None, None, None, None
+        if hasattr(args, "token_list") and args.token_list is not None:
+            symbol_table = read_symbol_table(args.token_list)
+        if hasattr(args, "seg_dict_file") and args.seg_dict_file is not None:
+            seg_dict = load_seg_dict(args.seg_dict_file)
+        if hasattr(args, "punc_list") and args.punc_list is not None:
+            punc_dict = read_symbol_table(args.punc_list)
+        if hasattr(args, "bpemodel") and args.bpemodel is not None:
+            bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel)
+        self.dataset_conf = args.dataset_conf
+        if "frontend_conf" not in args:
+            self.frontend_conf =  None
+        else:
+            self.frontend_conf = args.frontend_conf
+        self.speed_perturb = args.speed_perturb if hasattr(args, "speed_perturb") else None 
+        logging.info("dataloader config: {}".format(self.dataset_conf))
+        batch_mode = self.dataset_conf.get("batch_mode", "padding")
+        data_list = args.train_data_file if mode == "train" else args.valid_data_file
+        self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
+                               self.dataset_conf, self.frontend_conf,
+                               speed_perturb=self.speed_perturb if mode == "train" else None,
+                               mode=mode, batch_mode=batch_mode)
+
+    def build_iter(self, epoch, shuffle=True):
+        self.dataset.set_epoch(epoch)
+        data_loader = DataLoader(self.dataset,
+                                 batch_size=None,
+                                 pin_memory=True,
+                                 num_workers=self.dataset_conf.get("num_workers", 8))
+        return data_loader
diff --git a/funasr/datasets/large_datasets/collate_fn.py b/funasr/datasets/large_datasets/collate_fn.py
new file mode 100644
index 0000000..aff25a8
--- /dev/null
+++ b/funasr/datasets/large_datasets/collate_fn.py
@@ -0,0 +1,196 @@
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+from funasr.models.transformer.utils.nets_utils import pad_list, pad_list_all_dim
+
+
+class CommonCollateFn:
+    """Functor class of common_collate_fn()"""
+
+    def __init__(
+            self,
+            float_pad_value: Union[float, int] = 0.0,
+            int_pad_value: int = -32768,
+            not_sequence: Collection[str] = (),
+            max_sample_size=None
+    ):
+        self.float_pad_value = float_pad_value
+        self.int_pad_value = int_pad_value
+        self.not_sequence = set(not_sequence)
+        self.max_sample_size = max_sample_size
+
+    def __repr__(self):
+        return (
+            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
+            f"int_pad_value={self.float_pad_value})"
+        )
+
+    def __call__(
+            self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
+    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+        return common_collate_fn(
+            data,
+            float_pad_value=self.float_pad_value,
+            int_pad_value=self.int_pad_value,
+            not_sequence=self.not_sequence,
+        )
+
+
+def common_collate_fn(
+        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
+        float_pad_value: Union[float, int] = 0.0,
+        int_pad_value: int = -32768,
+        not_sequence: Collection[str] = (),
+) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+    """Concatenate ndarray-list to an array and convert to torch.Tensor.
+    """
+    uttids = [u for u, _ in data]
+    data = [d for _, d in data]
+
+    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
+    assert all(
+        not k.endswith("_lengths") for k in data[0]
+    ), f"*_lengths is reserved: {list(data[0])}"
+
+    output = {}
+    for key in data[0]:
+        if data[0][key].dtype.kind == "i":
+            pad_value = int_pad_value
+        else:
+            pad_value = float_pad_value
+
+        array_list = [d[key] for d in data]
+        tensor_list = [torch.from_numpy(a) for a in array_list]
+        tensor = pad_list(tensor_list, pad_value)
+        output[key] = tensor
+
+        if key not in not_sequence:
+            lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
+            output[key + "_lengths"] = lens
+
+    output = (uttids, output)
+    return output
+
+
+class DiarCollateFn:
+    """Functor class of common_collate_fn()"""
+
+    def __init__(
+            self,
+            float_pad_value: Union[float, int] = 0.0,
+            int_pad_value: int = -32768,
+            not_sequence: Collection[str] = (),
+            max_sample_size=None
+    ):
+        self.float_pad_value = float_pad_value
+        self.int_pad_value = int_pad_value
+        self.not_sequence = set(not_sequence)
+        self.max_sample_size = max_sample_size
+
+    def __repr__(self):
+        return (
+            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
+            f"int_pad_value={self.float_pad_value})"
+        )
+
+    def __call__(
+            self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
+    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+        return diar_collate_fn(
+            data,
+            float_pad_value=self.float_pad_value,
+            int_pad_value=self.int_pad_value,
+            not_sequence=self.not_sequence,
+        )
+
+
+def diar_collate_fn(
+        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
+        float_pad_value: Union[float, int] = 0.0,
+        int_pad_value: int = -32768,
+        not_sequence: Collection[str] = (),
+) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+    """Concatenate ndarray-list to an array and convert to torch.Tensor.
+    """
+    uttids = [u for u, _ in data]
+    data = [d for _, d in data]
+
+    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
+    assert all(
+        not k.endswith("_lengths") for k in data[0]
+    ), f"*_lengths is reserved: {list(data[0])}"
+
+    output = {}
+    for key in data[0]:
+        if data[0][key].dtype.kind == "i":
+            pad_value = int_pad_value
+        else:
+            pad_value = float_pad_value
+
+        array_list = [d[key] for d in data]
+        tensor_list = [torch.from_numpy(a) for a in array_list]
+        tensor = pad_list_all_dim(tensor_list, pad_value)
+        output[key] = tensor
+
+        if key not in not_sequence:
+            lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
+            output[key + "_lengths"] = lens
+
+    output = (uttids, output)
+    return output
+
+
+def crop_to_max_size(feature, target_size):
+    size = len(feature)
+    diff = size - target_size
+    if diff <= 0:
+        return feature
+
+    start = np.random.randint(0, diff + 1)
+    end = size - diff + start
+    return feature[start:end]
+
+
+def clipping_collate_fn(
+        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
+        max_sample_size=None,
+        not_sequence: Collection[str] = (),
+) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+    # mainly for pre-training
+    uttids = [u for u, _ in data]
+    data = [d for _, d in data]
+
+    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
+    assert all(
+        not k.endswith("_lengths") for k in data[0]
+    ), f"*_lengths is reserved: {list(data[0])}"
+
+    output = {}
+    for key in data[0]:
+        array_list = [d[key] for d in data]
+        tensor_list = [torch.from_numpy(a) for a in array_list]
+        sizes = [len(s) for s in tensor_list]
+        if max_sample_size is None:
+            target_size = min(sizes)
+        else:
+            target_size = min(min(sizes), max_sample_size)
+        tensor = tensor_list[0].new_zeros(len(tensor_list), target_size, tensor_list[0].shape[1])
+        for i, (source, size) in enumerate(zip(tensor_list, sizes)):
+            diff = size - target_size
+            if diff == 0:
+                tensor[i] = source
+            else:
+                tensor[i] = crop_to_max_size(source, target_size)
+        output[key] = tensor
+
+        if key not in not_sequence:
+            lens = torch.tensor([source.shape[0] for source in tensor], dtype=torch.long)
+            output[key + "_lengths"] = lens
+
+    output = (uttids, output)
+    return output
\ No newline at end of file
diff --git a/funasr/datasets/large_datasets/datapipes/__init__.py b/funasr/datasets/large_datasets/datapipes/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/large_datasets/datapipes/__init__.py
diff --git a/funasr/datasets/large_datasets/datapipes/batch.py b/funasr/datasets/large_datasets/datapipes/batch.py
new file mode 100644
index 0000000..35e5dba
--- /dev/null
+++ b/funasr/datasets/large_datasets/datapipes/batch.py
@@ -0,0 +1,213 @@
+import random
+
+from itertools import count
+from functools import partial
+from torch.utils.data import IterableDataset
+from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
+
+tiebreaker = count()
+
+
+def _default_len_fn(token):
+    return len(token), next(tiebreaker)
+
+
+def _token_len_fn(token, len_fn):
+    return len_fn(token), next(tiebreaker), token
+
+
+class MaxTokenBucketizerIterDataPipe(IterableDataset):
+
+    def __init__(
+            self,
+            datapipe,
+            batch_size=8000,
+            len_fn=_default_len_fn,
+            buffer_size=10240,
+            sort_size=500,
+            batch_mode="padding",
+    ):
+        assert batch_size > 0, "Batch size is required to be larger than 0!"
+        assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
+        assert sort_size > 0, "Sort size is required to be larger than 0!"
+
+        datapipe = MapperIterDataPipe(datapipe, fn=partial(_token_len_fn, len_fn=len_fn))
+        self.datapipe = datapipe
+        self.batch_size = batch_size
+        self.buffer_size = buffer_size
+        self.sort_size = sort_size
+        self.batch_mode = batch_mode
+
+    def set_epoch(self, epoch):
+        self.datapipe.set_epoch(epoch)
+
+    def __iter__(self):
+        buffer = []
+        batch = []
+        bucket = []
+        max_lengths = 0
+        min_lengths = 999999
+        batch_lengths = 0
+
+        if self.batch_mode == "clipping":
+            assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1"
+            for d in self.datapipe:
+                if d[0] > self.batch_size:
+                    continue
+                buffer.append(d)
+                if len(buffer) == self.buffer_size:
+                    random.shuffle(buffer)
+                    for sample in buffer:
+                        bucket.append(sample)
+                        if len(bucket) == self.sort_size:
+                            bucket.sort()
+                            for x in bucket:
+                                length, _, token = x
+                                if length < min_lengths:
+                                    min_lengths = length
+                                batch_lengths = min_lengths * (len(batch) + 1)
+                                if batch_lengths > self.batch_size:
+                                    yield batch
+                                    batch = []
+                                    min_lengths = length
+                                batch.append(token)
+                            bucket = []
+                    buffer = []
+
+            if buffer:
+                random.shuffle(buffer)
+                for sample in buffer:
+                    bucket.append(sample)
+                    if len(bucket) == self.sort_size:
+                        bucket.sort()
+                        for x in bucket:
+                            length, _, token = x
+                            if length < min_lengths:
+                                min_lengths = length
+                            batch_lengths = min_lengths * (len(batch) + 1)
+                            if batch_lengths > self.batch_size:
+                                yield batch
+                                batch = []
+                                min_lengths = length
+                            batch.append(token)
+                        bucket = []
+                buffer = []
+
+            if bucket:
+                bucket.sort()
+                for x in bucket:
+                    length, _, token = x
+                    if length < min_lengths:
+                        min_lengths = length
+                    batch_lengths = min_lengths * (len(batch) + 1)
+                    if batch_lengths > self.batch_size:
+                        yield batch
+                        batch = []
+                        min_lengths = length
+                    batch.append(token)
+                bucket = []
+
+            if batch:
+                yield batch
+
+        else:
+            if self.buffer_size == -1:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    buffer.append(d)
+                buffer.sort()
+                for sample in buffer:
+                    length, _, token = sample
+                    if length > max_lengths:
+                        max_lengths = length
+                    batch_lengths = max_lengths * (len(batch) + 1)
+                    if batch_lengths > self.batch_size:
+                        bucket.append(batch)
+                        batch = []
+                        max_lengths = length
+                    batch.append(token)
+                random.shuffle(bucket)
+                if bucket:
+                    for batch_sample in bucket:
+                        yield batch_sample
+                if batch:
+                    yield batch
+
+            elif self.buffer_size == 0:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    length, _, token = d
+                    if length > self.batch_size:
+                        continue
+                    if length > max_lengths:
+                        max_lengths = length
+                    batch_lengths = max_lengths * (len(batch) + 1)
+                    if batch_lengths > self.batch_size:
+                        yield batch
+                        batch = []
+                        max_lengths = length
+                    batch.append(token)
+                if batch:
+                    yield batch
+
+            else:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    buffer.append(d)
+                    if len(buffer) == self.buffer_size:
+                        random.shuffle(buffer)
+                        for sample in buffer:
+                            bucket.append(sample)
+                            if len(bucket) == self.sort_size:
+                                bucket.sort()
+                                for x in bucket:
+                                    length, _, token = x
+                                    if length > max_lengths:
+                                        max_lengths = length
+                                    batch_lengths = max_lengths * (len(batch) + 1)
+                                    if batch_lengths > self.batch_size:
+                                        yield batch
+                                        batch = []
+                                        max_lengths = length
+                                    batch.append(token)
+                                bucket = []
+                        buffer = []
+
+                if buffer:
+                    random.shuffle(buffer)
+                    for sample in buffer:
+                        bucket.append(sample)
+                        if len(bucket) == self.sort_size:
+                            bucket.sort()
+                            for x in bucket:
+                                length, _, token = x
+                                if length > max_lengths:
+                                    max_lengths = length
+                                batch_lengths = max_lengths * (len(batch) + 1)
+                                if batch_lengths > self.batch_size:
+                                    yield batch
+                                    batch = []
+                                    max_lengths = length
+                                batch.append(token)
+                            bucket = []
+                    buffer = []
+
+                if bucket:
+                    bucket.sort()
+                    for x in bucket:
+                        length, _, token = x
+                        if length > max_lengths:
+                            max_lengths = length
+                        batch_lengths = max_lengths * (len(batch) + 1)
+                        if batch_lengths > self.batch_size:
+                            yield batch
+                            batch = []
+                            max_lengths = length
+                        batch.append(token)
+                    bucket = []
+
+                if batch:
+                    yield batch
diff --git a/funasr/datasets/large_datasets/datapipes/filter.py b/funasr/datasets/large_datasets/datapipes/filter.py
new file mode 100644
index 0000000..6fe7153
--- /dev/null
+++ b/funasr/datasets/large_datasets/datapipes/filter.py
@@ -0,0 +1,24 @@
+from torch.utils.data import IterableDataset
+
+def default_fn(data):
+    return data
+
+
+class FilterIterDataPipe(IterableDataset):
+
+    def __init__(self,
+                 datapipe,
+                 fn=default_fn):
+        self.datapipe = datapipe
+        self.fn = fn
+
+    def set_epoch(self, epoch):
+        self.datapipe.set_epoch(epoch)
+
+    def __iter__(self):
+        assert callable(self.fn)
+        for data in self.datapipe:
+            if self.fn(data):
+                yield data
+            else:
+                continue
diff --git a/funasr/datasets/large_datasets/datapipes/map.py b/funasr/datasets/large_datasets/datapipes/map.py
new file mode 100644
index 0000000..dfcd6a0
--- /dev/null
+++ b/funasr/datasets/large_datasets/datapipes/map.py
@@ -0,0 +1,22 @@
+from torch.utils.data import IterableDataset
+
+
+def default_fn(data):
+    return data
+
+
+class MapperIterDataPipe(IterableDataset):
+
+    def __init__(self,
+                 datapipe,
+                 fn=default_fn):
+        self.datapipe = datapipe
+        self.fn = fn
+
+    def set_epoch(self, epoch):
+        self.datapipe.set_epoch(epoch)
+
+    def __iter__(self):
+        assert callable(self.fn)
+        for data in self.datapipe:
+            yield self.fn(data)
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
new file mode 100644
index 0000000..d3489c1
--- /dev/null
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -0,0 +1,274 @@
+import logging
+import os
+import random
+from functools import partial
+
+import torch
+import torch.distributed as dist
+import torchaudio
+import numpy as np
+# import librosa
+import librosa
+from kaldiio import ReadHelper
+from torch.utils.data import IterableDataset
+
+from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
+from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
+from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
+from funasr.datasets.large_datasets.utils.clipping import clipping
+from funasr.datasets.large_datasets.utils.filter import filter
+from funasr.datasets.large_datasets.utils.padding import padding
+from funasr.datasets.large_datasets.utils.tokenize import tokenize
+
+
+def read_lists(list_file):
+    lists = []
+    with open(list_file, 'r', encoding='utf8') as fin:
+        for line in fin:
+            parts = line.strip()
+            lists.append(parts)
+    return lists
+
+
+class AudioDataset(IterableDataset):
+    def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, speed_perturb=None,
+                 mode="train"):
+        self.scp_lists = scp_lists
+        self.data_names = data_names
+        self.data_types = data_types
+        self.frontend_conf = frontend_conf
+        self.shuffle = shuffle
+        self.mode = mode
+        self.epoch = -1
+        self.rank = 0
+        self.world_size = 1
+        self.worker_id = 0
+        self.num_workers = 1
+        self.speed_perturb = speed_perturb
+        if self.speed_perturb is not None:
+            logging.info("Using speed_perturb: {}".format(speed_perturb))
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+    def get_rank_data_list(self, data_index):
+        assert dist.is_available()
+        if dist.is_initialized():
+            self.rank = dist.get_rank()
+            self.world_size = dist.get_world_size()
+        else:
+            self.rank = 0
+            self.world_size = 1
+
+        if self.mode == "train":
+            if self.shuffle:
+                random.seed(self.epoch)
+                random.shuffle(data_index)
+            return data_index[self.rank::self.world_size]
+
+        return data_index
+
+    def get_worker_data_list(self, rank_data_index):
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is None:
+            self.worker_id = 0
+            self.num_workers = 1
+        else:
+            self.worker_id = worker_info.id
+            self.num_workers = worker_info.num_workers
+
+        return rank_data_index[self.worker_id::self.num_workers]
+
+    def close_reader(self, reader_list):
+        for reader in reader_list:
+            reader.close()
+
+    def __iter__(self):
+        data_index = list(range(len(self.scp_lists)))
+        rank_data_index = self.get_rank_data_list(data_index)
+        worker_data_index = self.get_worker_data_list(rank_data_index)
+
+        for index in worker_data_index:
+            data = dict(scp=self.scp_lists[index])
+
+            assert 'scp' in data
+            scp = data['scp']
+            data_file_list = scp.strip().split()
+            data_name_list = self.data_names.split(",")
+            data_type_list = self.data_types.split(",")
+
+            for file in data_file_list:
+                assert os.path.exists(file), "{} not exists".format(file)
+
+            assert len(data_file_list) == len(data_name_list) == len(data_type_list), \
+                "The item number of data, data_names, data_types must be the same "
+
+            reader_list = []
+            for data_file, data_type in zip(data_file_list, data_type_list):
+                if data_type == "kaldi_ark":
+                    ark_reader = ReadHelper('ark:{}'.format(data_file))
+                    reader_list.append(ark_reader)
+                elif data_type == "text" or data_type == "sound" or data_type == 'text_hotword':
+                    text_reader = open(data_file, "r", encoding="utf-8")
+                    reader_list.append(text_reader)
+                elif data_type == "none":
+                    continue
+                else:
+                    raise TypeError("Data type {} is not supported".format(data_type))
+
+            for items in zip(*reader_list):
+                sample_dict = {}
+                for item, (data_name, data_type) in zip(items, zip(data_name_list, data_type_list)):
+                    if data_type == "kaldi_ark":
+                        key, mat = item
+                        sample_dict[data_name] = mat
+                        if data_name == "speech":
+                            sample_dict["key"] = key
+                    elif data_type == "sound":
+                        key, path = item.strip().split()
+                        try:
+                            waveform, sampling_rate = torchaudio.load(path)
+                        except:
+                            # waveform, sampling_rate = librosa.load(path, dtype='float32')
+                            waveform, sampling_rate = librosa.load(path, dtype='float32')
+                            if waveform.ndim == 2:
+                                waveform = waveform[:, 0]
+                            waveform = np.expand_dims(waveform, axis=0)
+                            waveform = torch.tensor(waveform)
+                        if self.frontend_conf is not None:
+                            if sampling_rate != self.frontend_conf["fs"]:
+                                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
+                                                                          new_freq=self.frontend_conf["fs"])(waveform)
+                                sampling_rate = self.frontend_conf["fs"]
+                        waveform = waveform.numpy()
+                        mat = waveform[0]
+                        if self.speed_perturb is not None:
+                            speed = random.choice(self.speed_perturb)
+                            if speed != 1.0:
+                                mat, _ = torchaudio.sox_effects.apply_effects_tensor(
+                                    torch.tensor(mat).view(1, -1), sampling_rate, [['speed', str(speed)], ['rate', str(sampling_rate)]])
+                                mat = mat.view(-1).numpy()
+                        sample_dict[data_name] = mat
+                        sample_dict["sampling_rate"] = sampling_rate
+                        if data_name == "speech":
+                            sample_dict["key"] = key
+                    elif data_type == "text_hotword":
+                        text = item
+                        segs = text.strip().split()
+                        sample_dict[data_name] = segs[1:]
+                        if "key" not in sample_dict:
+                            sample_dict["key"] = segs[0]
+                        sample_dict['hw_tag'] = 1
+                    elif data_type == "text_nospace":
+                        text = item
+                        segs = text.strip().split(maxsplit=1)
+                        sample_dict[data_name] = [x for x in segs[1]]
+                        if "key" not in sample_dict:
+                            sample_dict["key"] = segs[0]
+                    else:
+                        text = item
+                        segs = text.strip().split()
+                        sample_dict[data_name] = segs[1:]
+                        if "key" not in sample_dict:
+                            sample_dict["key"] = segs[0]
+                yield sample_dict
+
+            self.close_reader(reader_list)
+
+
+def len_fn_example(data):
+    return 1
+
+
+def len_fn_token(data):
+    assert "speech" in data
+    if "sampling_rate" in data:
+        return (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+    else:
+        return data["speech"].shape[0]
+
+
+def Dataset(data_list_file,
+            dict,
+            seg_dict,
+            punc_dict,
+            bpe_tokenizer,
+            conf,
+            frontend_conf,
+            speed_perturb=None,
+            mode="train",
+            batch_mode="padding"):
+    scp_lists = read_lists(data_list_file)
+    shuffle = conf.get('shuffle', True)
+    data_names = conf.get("data_names", "speech,text")
+    data_types = conf.get("data_types", "kaldi_ark,text")
+
+    pre_hwfile = conf.get("pre_hwlist", None)
+    # pre_prob = conf.get("pre_prob", 0)  # unused yet
+    if pre_hwfile is not None:
+        pre_hwlist = []
+        with open(pre_hwfile, 'r', encoding="utf-8") as fin:
+            for line in fin.readlines():
+                pre_hwlist.append(line.strip())
+    else:
+        pre_hwlist = None
+
+    hw_config = {"sample_rate": conf.get("sample_rate", 0.6),
+                 "double_rate": conf.get("double_rate", 0.1),
+                 "hotword_min_length": conf.get("hotword_min_length", 2),
+                 "hotword_max_length": conf.get("hotword_max_length", 8),
+                 "pre_prob": conf.get("pre_prob", 0.0),
+                 "pre_hwlist": pre_hwlist}
+
+    
+
+    dataset = AudioDataset(scp_lists, 
+                           data_names, 
+                           data_types, 
+                           frontend_conf=frontend_conf, 
+                           shuffle=shuffle,
+                           speed_perturb=speed_perturb,
+                           mode=mode, 
+                           )
+
+    if "text" in data_names:
+        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer, 'hw_config': hw_config}
+        tokenize_fn = partial(tokenize, **vocab)
+        dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
+
+    filter_conf = conf.get('filter_conf', {})
+    filter_fn = partial(filter, **filter_conf)
+    dataset = FilterIterDataPipe(dataset, fn=filter_fn)
+
+    if shuffle:
+        buffer_conf = conf.get('shuffle_conf', {})
+        buffer_size = buffer_conf['shuffle_size']
+        sort_size = buffer_conf['sort_size']
+    else:
+        buffer_size = 0
+        sort_size = 1
+
+    batch_conf = conf.get('batch_conf', {})
+    batch_size = batch_conf['batch_size']
+    batch_type = batch_conf['batch_type']
+
+    assert batch_type in ["example", "token"]
+    if batch_type == 'example':
+        len_fn = len_fn_example
+    else:
+        len_fn = len_fn_token
+
+    dataset = MaxTokenBucketizerIterDataPipe(dataset,
+                                             batch_size=batch_size,
+                                             len_fn=len_fn,
+                                             buffer_size=buffer_size,
+                                             sort_size=sort_size,
+                                             batch_mode=batch_mode)
+
+    int_pad_value = conf.get("int_pad_value", -1)
+    float_pad_value = conf.get("float_pad_value", 0.0)
+    padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value}
+    padding_fn = partial(padding, **padding_conf)
+    dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping)
+
+    return dataset
diff --git a/funasr/datasets/large_datasets/utils/__init__.py b/funasr/datasets/large_datasets/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/__init__.py
diff --git a/funasr/datasets/large_datasets/utils/clipping.py b/funasr/datasets/large_datasets/utils/clipping.py
new file mode 100644
index 0000000..2554aba
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/clipping.py
@@ -0,0 +1,40 @@
+import numpy as np
+import torch
+
+from funasr.datasets.large_datasets.collate_fn import crop_to_max_size
+
+
+def clipping(data):
+    assert isinstance(data, list)
+    assert "key" in data[0]
+
+    keys = [x["key"] for x in data]
+
+    batch = {}
+    data_names = data[0].keys()
+    for data_name in data_names:
+        if data_name == "key":
+            continue
+        else:
+            if data[0][data_name].dtype.kind == "i":
+                tensor_type = torch.int64
+            else:
+                tensor_type = torch.float32
+
+            tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
+            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
+
+            length_clip = min(tensor_lengths)
+            tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
+            for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
+                diff = length - length_clip
+                assert diff >= 0
+                if diff == 0:
+                    tensor_clip[i] = tensor
+                else:
+                    tensor_clip[i] = crop_to_max_size(tensor, length_clip)
+
+            batch[data_name] = tensor_clip
+            batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)
+
+    return keys, batch
diff --git a/funasr/datasets/large_datasets/utils/filter.py b/funasr/datasets/large_datasets/utils/filter.py
new file mode 100644
index 0000000..1260a47
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/filter.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python
+
+
+def filter(data,
+           speech_length_min=100,
+           speech_length_max=15000,
+           token_length_min=0,
+           token_length_max=200):
+    assert "speech" in data or "text" in data
+
+    if "speech" in data and "text" in data:
+        if "sampling_rate" in data:
+            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+        else:
+            speech_length = data["speech"].shape[0]
+        num_tokens = len(data['text'])
+        return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
+    elif "speech" in data:
+        if "sampling_rate" in data:
+            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+        else:
+            speech_length = data["speech"].shape[0]
+        return speech_length_min < speech_length < speech_length_max
+    else:
+        num_tokens = len(data['text'])
+        return token_length_min < num_tokens < token_length_max
diff --git a/funasr/datasets/large_datasets/utils/hotword_utils.py b/funasr/datasets/large_datasets/utils/hotword_utils.py
new file mode 100644
index 0000000..73f8bdd
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/hotword_utils.py
@@ -0,0 +1,33 @@
+import random
+
+def sample_hotword(length, 
+                   hotword_min_length, 
+                   hotword_max_length,
+                   sample_rate,
+                   double_rate,
+                   pre_prob,
+                   pre_index=None,
+                   pre_hwlist=None):
+        if length < hotword_min_length:
+            return [-1]
+        if random.random() < sample_rate:
+            if pre_prob > 0 and random.random() < pre_prob and pre_index is not None:
+                return pre_index
+            if length == hotword_min_length:
+                return [0, length-1]
+            elif random.random() < double_rate and length > hotword_max_length + hotword_min_length + 2:
+                # sample two hotwords in a sentence
+                _max_hw_length = min(hotword_max_length, length // 2)
+                # first hotword
+                start1 = random.randint(0, length // 3)
+                end1 = random.randint(start1 + hotword_min_length - 1, start1 + _max_hw_length - 1)
+                # second hotword
+                start2 = random.randint(end1 + 1, length - hotword_min_length)
+                end2 = random.randint(min(length-1, start2+hotword_min_length-1), min(length-1, start2+hotword_max_length-1))
+                return [start1, end1, start2, end2]
+            else:  # single hotword
+                start = random.randint(0, length - hotword_min_length)
+                end = random.randint(min(length-1, start+hotword_min_length-1), min(length-1, start+hotword_max_length-1))
+                return [start, end]
+        else:
+            return [-1]
\ No newline at end of file
diff --git a/funasr/datasets/large_datasets/utils/low_frame_rate.py b/funasr/datasets/large_datasets/utils/low_frame_rate.py
new file mode 100644
index 0000000..76eb2da
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/low_frame_rate.py
@@ -0,0 +1,30 @@
+import numpy as np
+
+
+def build_LFR_features(data, m, n):
+    """
+    Actually, this implements stacking frames and skipping frames.
+    if m = 1 and n = 1, just return the origin features.
+    if m = 1 and n > 1, it works like skipping.
+    if m > 1 and n = 1, it works like stacking but only support right frames.
+    if m > 1 and n > 1, it works like LFR.
+
+    Args:
+        inputs_batch: inputs is T x D np.ndarray
+        m: number of frames to stack
+        n: number of frames to skip
+    """
+
+    LFR_inputs = []
+    T = data.shape[0]
+    T_lfr = int(np.ceil(T / n))
+    for i in range(T_lfr):
+        if m <= T - i * n:
+            LFR_inputs.append(np.hstack(data[i*n:i*n+m]))
+        else:
+            num_padding = m - (T - i * n)
+            frame = np.hstack(data[i*n:])
+            for _ in range(num_padding):
+                frame = np.hstack((frame, data[-1]))
+            LFR_inputs.append(frame)
+    return np.vstack(LFR_inputs)
diff --git a/funasr/datasets/large_datasets/utils/padding.py b/funasr/datasets/large_datasets/utils/padding.py
new file mode 100644
index 0000000..26c6e84
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/padding.py
@@ -0,0 +1,74 @@
+import numpy as np
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+def padding(data, float_pad_value=0.0, int_pad_value=-1):
+    assert isinstance(data, list)
+    assert "key" in data[0]
+    assert "speech" in data[0] or "text" in data[0]
+    
+    keys = [x["key"] for x in data]
+
+    batch = {}
+    data_names = data[0].keys()
+    for data_name in data_names:
+        if data_name == "key" or data_name == "sampling_rate":
+            continue
+        else:
+            if data_name != 'hotword_indxs':
+                if data[0][data_name].dtype.kind == "i":
+                    pad_value = int_pad_value
+                    tensor_type = torch.int64
+                else:
+                    pad_value = float_pad_value
+                    tensor_type = torch.float32
+
+            tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
+            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
+            tensor_pad = pad_sequence(tensor_list,
+                                      batch_first=True,
+                                      padding_value=pad_value)
+            batch[data_name] = tensor_pad
+            batch[data_name + "_lengths"] = tensor_lengths
+
+    # SAC LABEL INCLUDE
+    if "hotword_indxs" in batch:
+        # if hotword indxs in batch
+        # use it to slice hotwords out
+        hotword_list = []
+        hotword_lengths = []
+        text = batch['text']
+        text_lengths = batch['text_lengths']
+        hotword_indxs = batch['hotword_indxs']
+        dha_pad = torch.ones_like(text) * -1
+        _, t1 = text.shape
+        t1 += 1  # TODO: as parameter which is same as predictor_bias
+        nth_hw = 0
+        for b, (hotword_indx, one_text, length) in enumerate(zip(hotword_indxs, text, text_lengths)):
+            dha_pad[b][:length] = 8405
+            if hotword_indx[0] != -1:
+                start, end = int(hotword_indx[0]), int(hotword_indx[1])
+                hotword = one_text[start: end+1]
+                hotword_list.append(hotword)
+                hotword_lengths.append(end-start+1)
+                dha_pad[b][start: end+1] = one_text[start: end+1]
+                nth_hw += 1
+                if len(hotword_indx) == 4 and hotword_indx[2] != -1:
+                    # the second hotword if exist
+                    start, end = int(hotword_indx[2]), int(hotword_indx[3])
+                    hotword_list.append(one_text[start: end+1])
+                    hotword_lengths.append(end-start+1)
+                    dha_pad[b][start: end+1] = one_text[start: end+1]
+                    nth_hw += 1
+        hotword_list.append(torch.tensor([1]))
+        hotword_lengths.append(1)
+        hotword_pad = pad_sequence(hotword_list,
+                                batch_first=True,
+                                padding_value=0)
+        batch["hotword_pad"] = hotword_pad
+        batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
+        batch['dha_pad'] = dha_pad
+        del batch['hotword_indxs']
+        del batch['hotword_indxs_lengths']
+    return keys, batch
diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py
new file mode 100644
index 0000000..34a97c1
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/tokenize.py
@@ -0,0 +1,95 @@
+#!/usr/bin/env python
+import re
+import numpy as np
+from funasr.datasets.large_datasets.utils.hotword_utils import sample_hotword
+
+def forward_segment(text, seg_dict):
+    word_list = []
+    i = 0
+    while i < len(text):
+        longest_word = text[i]
+        for j in range(i + 1, len(text) + 1):
+            word = text[i:j]
+            if word in seg_dict:
+                if len(word) > len(longest_word):
+                    longest_word = word
+        word_list.append(longest_word)
+        i += len(longest_word)
+    return word_list
+
+def seg_tokenize(txt, seg_dict):
+    pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+    out_txt = ""
+    for word in txt:
+        word = word.lower()
+        if word in seg_dict:
+            out_txt += seg_dict[word] + " "
+        else:
+            if pattern.match(word):
+                for char in word:
+                    if char in seg_dict:
+                        out_txt += seg_dict[char] + " "
+                    else:
+                        out_txt += "<unk>" + " "
+            else:
+                out_txt += "<unk>" + " "
+    return out_txt.strip().split()
+
+def tokenize(data,
+             vocab=None,
+             seg_dict=None,
+             punc_dict=None,
+             bpe_tokenizer=None,
+             hw_config=None):
+    assert "text" in data
+    assert isinstance(vocab, dict)
+    text = data["text"]
+    token = []
+    vad = -2
+    if bpe_tokenizer is not None:
+        text = bpe_tokenizer.text2tokens(" ".join(text))
+    if seg_dict is not None:
+        assert isinstance(seg_dict, dict)
+        text = seg_tokenize(text, seg_dict)
+
+    length = len(text)
+    if 'hw_tag' in data:
+        pre_index = None
+        if hw_config['pre_hwlist'] is not None and hw_config['pre_prob'] > 0:
+            # enable preset hotword detect in sampling
+            for hw in hw_config['pre_hwlist']:
+                hw = " ".join(seg_tokenize(hw, seg_dict))
+                _find = " ".join(text).find(hw)
+                if _find != -1:
+                    # _find = text[:_find].count(" ")  # bpe sometimes
+                    pre_index = [_find, _find + max(hw.count(" "), 1)]
+                    break
+        hotword_indxs = sample_hotword(length, **hw_config, pre_index=pre_index)
+        data['hotword_indxs'] = hotword_indxs
+        del data['hw_tag']
+    for i in range(length):
+        x = text[i]
+        if i == length-1 and "punc" in data and x.startswith("vad:"):
+            vad = x[4:]
+            if len(vad) == 0:
+                vad = -1
+            else:
+                vad = int(vad)
+        elif x in vocab:
+            token.append(vocab[x])
+        else:
+            token.append(vocab['<unk>'])
+
+    if "punc" in data and punc_dict is not None:
+        punc_token = []
+        for punc in data["punc"]:
+            if punc in punc_dict:
+                punc_token.append(punc_dict[punc])
+            else:
+                punc_token.append(punc_dict["_"])
+        data["punc"] =  np.array(punc_token)
+
+    data["text"] = np.array(token)
+    if vad is not -2:
+        data["vad_indexes"]=np.array([vad], dtype=np.int64)
+    return data
diff --git a/funasr/datasets/llm_datasets/samplers.py b/funasr/datasets/llm_datasets/samplers.py
deleted file mode 100644
index 914e776..0000000
--- a/funasr/datasets/llm_datasets/samplers.py
+++ /dev/null
@@ -1,277 +0,0 @@
-import torch
-import numpy as np
-import logging
-import torch.distributed as dist
-
-from funasr.register import tables
-
-
-@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
-class BatchSampler(torch.utils.data.BatchSampler):
-    
-    def __init__(self, dataset,
-                 batch_type: str = "example",
-                 batch_size: int = 100,
-                 buffer_size: int = 30,
-                 drop_last: bool = False,
-                 shuffle: bool = True,
-                 is_training: bool = True,
-                 **kwargs):
-        
-        self.drop_last = drop_last
-        self.pre_idx = -1
-        self.dataset = dataset
-        self.total_samples = len(dataset)
-        self.batch_type = batch_type
-        self.batch_size = int(batch_size)
-        self.buffer_size = buffer_size
-        self.max_token_length = kwargs.get("max_token_length", 5000)
-        self.shuffle_idx = np.arange(self.total_samples)
-        self.shuffle = shuffle and is_training
-        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-        
-    
-    def __len__(self):
-        return (self.total_samples-1) // self.batch_size + 1
-    
-    def set_epoch(self, epoch):
-        np.random.seed(epoch)
-    
-    def __iter__(self):
-        
-        if self.shuffle:
-            np.random.shuffle(self.shuffle_idx)
-        
-        batch = []
-        max_token = 0
-        num_sample = 0
-        
-        iter_num = (self.total_samples - 1) // self.buffer_size + 1
-        # print("iter_num: ", iter_num)
-        for iter in range(self.pre_idx + 1, iter_num):
-            datalen_with_index = []
-            for i in range(self.buffer_size):
-                idx = iter * self.buffer_size + i
-                if idx >= self.total_samples:
-                    continue
-                
-                idx_map = self.shuffle_idx[idx]
-                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
-                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
-                sample_len_cur = source_len + target_len
-                
-                
-                datalen_with_index.append([idx, sample_len_cur])
-            
-            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
-            for item in datalen_with_index_sort:
-                idx, sample_len_cur_raw = item
-                if sample_len_cur_raw > self.max_token_length:
-                    continue
-                
-                max_token_cur = max(max_token, sample_len_cur_raw)
-                max_token_padding = 1 + num_sample
-                if self.batch_type != 'example':
-                    max_token_padding *= max_token_cur
-                if max_token_padding <= self.batch_size:
-                    batch.append(idx)
-                    max_token = max_token_cur
-                    num_sample += 1
-                else:
-                    yield batch
-                    batch = [idx]
-                    max_token = sample_len_cur_raw
-                    num_sample = 1
-
-
-@tables.register("batch_sampler_classes", "BatchSampler")
-@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler")
-class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler):
-    
-    def __init__(self, dataset,
-                 batch_type: str = "example",
-                 batch_size: int = 100,
-                 buffer_size: int = 30,
-                 drop_last: bool = True,
-                 shuffle: bool = True,
-                 is_training: bool = True,
-                 **kwargs):
-        
-        self.drop_last = drop_last
-        self.pre_idx = -1
-        self.dataset = dataset
-        self.total_samples = len(dataset)
-        self.batch_type = batch_type
-        self.batch_size = int(batch_size)
-        self.buffer_size = buffer_size
-        self.max_token_length = kwargs.get("max_token_length", 1500)
-        self.shuffle_idx = np.arange(self.total_samples)
-        self.shuffle = shuffle and is_training
-        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-        
-        try:
-            rank = dist.get_rank()
-            world_size = dist.get_world_size()
-        except:
-            rank = 0
-            world_size = 1
-        self.rank = rank
-        self.world_size = world_size
-        
-    def __len__(self):
-        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
-    
-    def set_epoch(self, epoch):
-        np.random.seed(epoch)
-    
-    def __iter__(self):
-    
-        batch_size_total = self.batch_size * self.world_size
-        
-        if self.shuffle:
-            np.random.shuffle(self.shuffle_idx)
-        
-        batch = []
-        max_token = 0
-        num_sample = 0
-        
-        iter_num = (self.total_samples - 1) // self.buffer_size + 1
-        # print("iter_num: ", iter_num)
-        for iter in range(self.pre_idx + 1, iter_num):
-            # if iter == iter_num -1 and self.drop_last:
-            #     continue
-            datalen_with_index = []
-            for i in range(self.buffer_size):
-                idx = iter * self.buffer_size + i
-                if idx >= self.total_samples:
-                    continue
-                
-                idx_map = self.shuffle_idx[idx]
-                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-                
-                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
-                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
-                sample_len_cur = source_len + target_len
-                
-                datalen_with_index.append([idx, sample_len_cur])
-            
-            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
-            for item in datalen_with_index_sort:
-                idx, sample_len_cur_raw = item
-                if sample_len_cur_raw > self.max_token_length:
-                    continue
-
-                max_token_cur = max(max_token, sample_len_cur_raw)
-                max_token_padding = 1 + num_sample
-                # if self.batch_type != 'example':
-                #     max_token_padding *= max_token_cur
-                if max_token_padding <= batch_size_total:
-                    batch.append(idx)
-                    max_token = max_token_cur
-                    num_sample += 1
-                else:
-                    batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size]
-                    yield batch_rank
-                    batch = [idx]
-                    max_token = sample_len_cur_raw
-                    num_sample = 1
-
-
-@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler")
-class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler):
-    
-    def __init__(self, dataset,
-                 batch_type: str = "example",
-                 batch_size: int = 100,
-                 buffer_size: int = 30,
-                 drop_last: bool = True,
-                 shuffle: bool = True,
-                 is_training: bool = True,
-                 **kwargs):
-        
-        self.drop_last = drop_last
-        self.pre_idx = -1
-        self.dataset = dataset
-        self.total_samples = len(dataset)
-        self.batch_type = batch_type
-        self.batch_size = int(batch_size)
-        self.buffer_size = buffer_size
-        self.max_token_length = kwargs.get("max_token_length", 1500)
-        self.shuffle_idx = np.arange(self.total_samples)
-        self.shuffle = shuffle and is_training
-        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-        
-        try:
-            rank = dist.get_rank()
-            world_size = dist.get_world_size()
-        except:
-            rank = 0
-            world_size = 1
-        self.rank = rank
-        self.world_size = world_size
-    
-    def __len__(self):
-        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
-    
-    def set_epoch(self, epoch):
-        np.random.seed(epoch)
-    
-    def __iter__(self):
-        
-        batch_size_total = self.batch_size * self.world_size
-        if self.shuffle:
-            np.random.shuffle(self.shuffle_idx)
-        
-        batch_list_all_rank = []
-        batch_list_cur = []
-        max_token = 0
-        num_sample = 0
-        
-        iter_num = (self.total_samples - 1) // self.buffer_size + 1
-        # print("iter_num: ", iter_num)
-        for iter in range(self.pre_idx + 1, iter_num):
-            # if iter == iter_num - 1 and self.drop_last:
-            #     continue
-            datalen_with_index = []
-            for i in range(self.buffer_size):
-                idx = iter * self.buffer_size + i
-                if idx >= self.total_samples:
-                    continue
-                
-                idx_map = self.shuffle_idx[idx]
-                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-                
-                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
-                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
-                sample_len_cur = source_len + target_len
-                
-                datalen_with_index.append([idx, sample_len_cur])
-            
-            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
-            for ii, item in enumerate(datalen_with_index_sort):
-                is_last_batch = iter == iter_num - 1 and ii == len(datalen_with_index_sort)
-                idx, sample_len_cur_raw = item
-                if sample_len_cur_raw > self.max_token_length:
-                    continue
-                
-                max_token_cur = max(max_token, sample_len_cur_raw)
-                max_token_padding = 1 + num_sample
-                
-                if self.batch_type != 'example':
-                    max_token_padding *= max_token_cur
-                if len(batch_list_all_rank) < self.world_size:
-                    
-                    if max_token_padding <= self.batch_size:
-                        batch_list_cur.append(idx)
-                        max_token = max_token_cur
-                        num_sample += 1
-                    else:
-                        batch_list_all_rank.append(batch_list_cur)
-                        batch_list_cur = []
-                else:
-                    batch_rank = batch_list_all_rank[self.rank]
-                    yield batch_rank
-                    batch_list_all_rank = [idx]
-                    max_token = sample_len_cur_raw
-                    num_sample = 1
diff --git a/funasr/datasets/llm_datasets_vicuna/samplers.py b/funasr/datasets/llm_datasets_vicuna/samplers.py
deleted file mode 100644
index 61f7d94..0000000
--- a/funasr/datasets/llm_datasets_vicuna/samplers.py
+++ /dev/null
@@ -1,431 +0,0 @@
-import torch
-import numpy as np
-import logging
-import math
-import torch.distributed as dist
-from torch.utils.data import DistributedSampler
-from torch.utils.data import BatchSampler, Sampler
-import torch.distributed as dist
-
-from funasr.register import tables
-
-
-@tables.register("batch_sampler_classes", "RankFullGlobalShuffleBatchSampler")
-class RankFullGlobalShuffleBatchSampler(torch.utils.data.BatchSampler):
-    
-    def __init__(self, dataset,
-                 batch_type: str = "example",
-                 batch_size: int = 100,
-                 buffer_size: int = 30,
-                 drop_last: bool = True,
-                 shuffle: bool = True,
-                 is_training: bool = True,
-                 **kwargs):
-        
-        self.drop_last = drop_last
-        self.pre_idx = -1
-        self.dataset = dataset
-        self.total_samples = len(dataset)
-        self.batch_type = batch_type
-        self.batch_size = int(batch_size)
-        self.buffer_size = buffer_size
-        self.max_token_length = kwargs.get("max_token_length", 1500)
-        self.shuffle_idx = np.arange(self.total_samples)
-        self.shuffle = shuffle and is_training
-        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-        
-        try:
-            rank = dist.get_rank()
-            world_size = dist.get_world_size()
-        except:
-            rank = 0
-            world_size = 1
-        self.rank = rank
-        self.world_size = world_size
-        
-    def __len__(self):
-        return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1
-    
-    def set_epoch(self, epoch):
-        np.random.seed(epoch)
-    
-    def __iter__(self):
-    
-        batch_size_total = self.batch_size * self.world_size
-        
-        if self.shuffle:
-            np.random.shuffle(self.shuffle_idx)
-        
-        batch = []
-        max_token = 0
-        num_sample = 0
-        
-        iter_num = (self.total_samples - 1) // self.buffer_size + 1
-        # print("iter_num: ", iter_num)
-        for iter in range(self.pre_idx + 1, iter_num):
-            # if iter == iter_num -1 and self.drop_last:
-            #     continue
-            datalen_with_index = []
-            for i in range(self.buffer_size):
-                idx = iter * self.buffer_size + i
-                if idx >= self.total_samples:
-                    continue
-                
-                idx_map = self.shuffle_idx[idx]
-                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-                
-                source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
-                target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
-                sample_len_cur = source_len + target_len
-                
-                datalen_with_index.append([idx, sample_len_cur])
-            
-            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
-            for item in datalen_with_index_sort:
-                idx, sample_len_cur_raw = item
-                if sample_len_cur_raw > self.max_token_length:
-                    continue
-
-                max_token_cur = max(max_token, sample_len_cur_raw)
-                max_token_padding = 1 + num_sample
-                # if self.batch_type != 'example':
-                #     max_token_padding *= max_token_cur
-                if max_token_padding <= batch_size_total:
-                    batch.append(idx)
-                    max_token = max_token_cur
-                    num_sample += 1
-                else:
-                    batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size]
-                    yield batch_rank
-                    batch = [idx]
-                    max_token = sample_len_cur_raw
-                    num_sample = 1
-
-@tables.register("batch_sampler_classes", "DistributedSamplerWarp")
-class DistributedSamplerWarp(BatchSampler):
-    def __init__(self, dataset, batch_size, num_replicas=None, rank=None, shuffle=True, drop_last=False):
-        if num_replicas is None:
-            if not torch.distributed.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            num_replicas = torch.distributed.get_world_size()
-        if rank is None:
-            if not torch.distributed.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            rank = torch.distributed.get_rank()
-        
-        self.dataset = dataset
-        self.batch_size = batch_size
-        self.num_replicas = num_replicas
-        self.rank = rank
-        self.shuffle = shuffle
-        self.drop_last = drop_last
-        
-        # Create an instance of the DistributedSampler
-        self.sampler = DistributedSampler(
-            self.dataset,
-            num_replicas=self.num_replicas,
-            rank=self.rank,
-            shuffle=self.shuffle
-        )
-        
-        # Call BatchSampler's constructor
-        super().__init__(self.sampler, batch_size, drop_last)
-    
-    def __iter__(self):
-        # If we shuffle, we need to call the set_epoch method
-        if self.shuffle:
-            self.sampler.set_epoch(self.epoch)
-        
-        # Generate batch indices using the parent class
-        return super().__iter__()
-    
-    def set_epoch(self, epoch):
-        self.epoch = epoch
-
-@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler_fn")
-def CustomDistributedBatchSampler_fn(dataset, **kwargs):
-    dataloader_args = {}
-    dataloader_args["batch_sampler"] = CustomDistributedBatchSampler(dataset, **kwargs)
-    dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
-    dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
-    
-    return dataloader_args
-
-@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler")
-class CustomDistributedBatchSampler(Sampler):
-    def __init__(self, dataset,
-                 batch_size,
-                 num_replicas=None,
-                 rank=None,
-                 shuffle=True,
-                 drop_last=False,
-                 is_training: bool = True,
-                 **kwargs,
-                 ):
-
-        try:
-            rank = dist.get_rank()
-            num_replicas = dist.get_world_size()
-        except:
-            rank = 0
-            num_replicas = 1
-        self.rank = rank
-        self.num_replicas = num_replicas
-        self.dataset = dataset
-        self.batch_size = batch_size
-        self.is_training = is_training
-        self.shuffle = shuffle and is_training
-        self.drop_last = drop_last
-        # self.total_size = len(dataset)
-        if self.drop_last:
-            self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (batch_size * num_replicas)
-        else:
-            self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (batch_size * num_replicas)
-        self.num_samples = int(self.total_size // self.num_replicas)
-        self.epoch = 0
-        self.max_token_length = kwargs.get("max_token_length", None)
-        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-
-    def __iter__(self):
-        # Generate a list of indices
-        if self.shuffle:
-            g = torch.Generator()
-            g.manual_seed(self.epoch)
-            indices = torch.randperm(len(self.dataset), generator=g).tolist()
-        else:
-            indices = list(range(len(self.dataset)))
-
-        # Add extra samples to make it evenly divisible
-        padding_size = self.total_size - len(indices)
-        if padding_size <= len(indices):
-            indices += indices[:padding_size]
-        else:
-            indices += (indices * (padding_size // len(indices)) + indices[:padding_size % len(indices)])
-
-        assert len(indices) == self.total_size
-
-        # Subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
-        assert len(indices) == self.num_samples
-
-        # Filter out indices with length greater than the max length, if provided
-        if self.max_token_length is not None:
-            filtered_indices = []
-            for idx in indices:
-                source_len = self.dataset.get_source_len(idx) / self.length_scale_source
-                if source_len <= self.max_token_length:
-                    filtered_indices.append(idx)
-            indices = filtered_indices
-
-        # Now that we have only the indices for this replica, chunk them into batches
-        batches = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)]
-
-        # Drop the last batch if it's not full and drop_last is True
-        if self.drop_last and len(batches[-1]) != self.batch_size:
-            batches = batches[:-1]
-
-        return iter(batches)
-
-    def __len__(self):
-
-        return self.num_samples // self.batch_size
-
-    def set_epoch(self, epoch):
-        self.epoch = epoch
-
-
-@tables.register("batch_sampler_classes", "CustomDistributedBufferBatchSampler_fn")
-def CustomDistributedBatchSampler_fn(dataset, **kwargs):
-    dataloader_args = {}
-    dataloader_args["batch_sampler"] = CustomDistributedBufferBatchSampler(dataset, **kwargs)
-    dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
-    dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
-    
-    return dataloader_args
-
-
-@tables.register("batch_sampler_classes", "CustomDistributedBufferBatchSampler")
-class CustomDistributedBatchSampler(Sampler):
-    def __init__(self, dataset,
-                 batch_size,
-                 num_replicas=None,
-                 rank=None,
-                 shuffle=True,
-                 drop_last=False,
-                 is_training: bool = True,
-                 sort_size: int = 1024,
-                 **kwargs,
-                 ):
-        
-        try:
-            rank = dist.get_rank()
-            num_replicas = dist.get_world_size()
-        except:
-            rank = 0
-            num_replicas = 1
-        self.rank = rank
-        self.num_replicas = num_replicas
-        self.dataset = dataset
-        self.batch_size = batch_size
-        self.is_training = is_training
-        self.shuffle = shuffle and is_training
-        self.drop_last = drop_last
-        # self.total_size = len(dataset)
-        if self.drop_last:
-            self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (batch_size * num_replicas)
-        else:
-            self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (batch_size * num_replicas)
-        self.num_samples = int(self.total_size // self.num_replicas)
-        self.epoch = 0
-        self.max_token_length = kwargs.get("max_token_length", None)
-        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-        self.sort_size = sort_size
-    
-    def __iter__(self):
-        # Generate a list of indices
-        if self.shuffle:
-            g = torch.Generator()
-            g.manual_seed(self.epoch)
-            indices = torch.randperm(len(self.dataset), generator=g).tolist()
-        else:
-            indices = list(range(len(self.dataset)))
-        
-        # Add extra samples to make it evenly divisible
-        padding_size = self.total_size - len(indices)
-        if padding_size <= len(indices):
-            indices += indices[:padding_size]
-        else:
-            indices += (indices * (padding_size // len(indices)) + indices[:padding_size % len(indices)])
-        
-        assert len(indices) == self.total_size
-        
-        # Subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
-        assert len(indices) == self.num_samples
-        
-        # Filter out indices with length greater than the max length, if provided
-        if self.max_token_length is not None:
-            filtered_indices = []
-            for idx in indices:
-                source_len = self.dataset.get_source_len(idx) / self.length_scale_source
-                if source_len <= self.max_token_length:
-                    filtered_indices.append(idx)
-            indices = filtered_indices
-
-        # Buffer sorting logic
-        sorted_batches = []
-        buffer = []
-
-        for idx in indices:
-            buffer.append(idx)
-            if len(buffer) >= self.sort_size:
-                # Sort the buffer based on some criteria, e.g., dataset sample length
-                buffer.sort(key=lambda x: self.dataset.get_source_len(x))
-                sorted_batches.extend(self._create_batches_from_buffer(buffer))
-                buffer = []
-
-        # Handle the remaining items in the buffer
-        if buffer:
-            buffer.sort(key=lambda x: self.dataset.get_source_len(x))
-            sorted_batches.extend(self._create_batches_from_buffer(buffer))
-
-        return iter(sorted_batches)
-
-    def _create_batches_from_buffer(self, buffer):
-        # Function to convert the sorted buffer into batches
-        batched_buffer = [buffer[i:i + self.batch_size] for i in range(0, len(buffer), self.batch_size)]
-        if self.drop_last and len(batched_buffer[-1]) != self.batch_size:
-            batched_buffer = batched_buffer[:-1]
-        return batched_buffer
-    
-    def __len__(self):
-        
-        return self.num_samples // self.batch_size
-    
-    def set_epoch(self, epoch):
-        self.epoch = epoch
-
-
-@tables.register("batch_sampler_classes", "CustomDistributedDynamicBatchSampler_fn")
-def CustomDistributedBatchSampler_fn(dataset, **kwargs):
-    dataloader_args = {}
-    dataloader_args["batch_sampler"] = CustomDistributedDynamicBatchSampler(dataset, **kwargs)
-    dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
-    dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
-    
-    return dataloader_args
-
-
-@tables.register("batch_sampler_classes", "CustomDistributedDynamicBatchSampler")
-class CustomDistributedDynamicBatchSampler(Sampler):
-    def __init__(self, dataset,
-                 batch_size,
-                 num_replicas=None,
-                 rank=None,
-                 shuffle=True,
-                 drop_last=False,
-                 is_training: bool = True,
-                 **kwargs,
-                 ):
-        
-        try:
-            rank = dist.get_rank()
-            num_replicas = dist.get_world_size()
-        except:
-            rank = 0
-            num_replicas = 1
-        self.rank = rank
-        self.num_replicas = num_replicas
-        self.dataset = dataset
-        self.batch_size = batch_size
-        self.is_training = is_training
-        self.shuffle = shuffle and is_training
-        self.drop_last = drop_last
-        
-        self.total_size = len(self.dataset)
-        # self.num_samples = int(math.ceil(self.total_size / self.num_replicas))
-        self.epoch = 0
-    
-    def __iter__(self):
-        if self.shuffle:
-            g = torch.Generator()
-            g.manual_seed(self.epoch)
-            indices = torch.randperm(len(self.dataset), generator=g).tolist()
-        else:
-            indices = list(range(len(self.dataset)))
-        
-        indices = indices[self.rank:self.total_size:self.num_replicas]
-        
-        batches = []
-        batch = []
-        max_len_in_batch = 0
-        current_batch_length = 0
-        
-        for idx in indices:
-            sample_length = self.dataset.get_source_len(idx)
-            potential_batch_length = (max_len_in_batch if sample_length < max_len_in_batch else sample_length) * (
-                    len(batch) + 1)
-            
-            if potential_batch_length <= self.batch_size:
-                batch.append(idx)
-                if sample_length > max_len_in_batch:
-                    max_len_in_batch = sample_length
-                    current_batch_length = max_len_in_batch * len(batch)
-            else:
-                batches.append(batch)
-                batch = [idx]
-                max_len_in_batch = sample_length
-                current_batch_length = max_len_in_batch
-        
-        # Add the last batch if it's not empty and we're not dropping it
-        if batch and (not self.drop_last or len(batch) * max_len_in_batch == self.batch_size):
-            batches.append(batch)
-        
-        return iter(batches)
-    
-    def __len__(self):
-        
-        return -1
-    
-    def set_epoch(self, epoch):
-        self.epoch = epoch
diff --git a/funasr/register.py b/funasr/register.py
index cfa1b20..45e2a85 100644
--- a/funasr/register.py
+++ b/funasr/register.py
@@ -15,6 +15,7 @@
     predictor_classes = {}
     stride_conv_classes = {}
     tokenizer_classes = {}
+    dataloader_classes = {}
     batch_sampler_classes = {}
     dataset_classes = {}
     index_ds_classes = {}
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index aae4513..c443c6f 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -1,3 +1,4 @@
+import math
 import os
 import time
 import torch
@@ -7,8 +8,6 @@
 import torch.distributed as dist
 from torch.cuda.amp import autocast, GradScaler
 from contextlib import nullcontext, contextmanager
-# from torch.utils.tensorboard import SummaryWriter
-from tensorboardX import SummaryWriter
 from pathlib import Path
 
 from funasr.train_utils.device_funcs import to_device
@@ -40,11 +39,7 @@
         resume (str, optional): Path to a checkpoint to resume training from.
     """
     
-    def __init__(self, model,
-                 optim,
-                 scheduler,
-                 dataloader_train,
-                 dataloader_val,
+    def __init__(self,
                  local_rank,
                  use_ddp: bool = False,
                  use_fsdp: bool = False,
@@ -66,29 +61,31 @@
                       resume (str, optional): The file path to a checkpoint to resume training from.
         """
         
-        self.model = model
-        self.optim = optim
-        self.scheduler = scheduler
-        self.dataloader_train = dataloader_train
-        self.dataloader_val = dataloader_val
         self.output_dir = output_dir
+        if not os.path.exists(self.output_dir):
+            os.makedirs(self.output_dir, exist_ok=True)
         self.resume = kwargs.get('resume', True)
         self.start_epoch = 0
         self.max_epoch = kwargs.get('max_epoch', 100)
         self.local_rank = local_rank
         self.use_ddp = use_ddp
         self.use_fsdp = use_fsdp
-        self.device = next(model.parameters()).device
+        self.device = kwargs.get('device', "cuda")
         self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
-        self.kwargs = kwargs
+        # self.kwargs = kwargs
         self.log_interval = kwargs.get("log_interval", 50)
         self.batch_total = 0
         self.use_fp16 = use_fp16
         self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
-        scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
-        scaler = ShardedGradScaler(enabled=use_fp16) if use_ddp else scaler
-        self.scaler = scaler
+        # scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
+        # scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
+        # self.scaler = scaler
         self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
+        self.keep_nbest_models = kwargs.get("keep_nbest_models", -1)
+        self.accum_grad = kwargs.get("accum_grad", 1)
+        self.grad_clip = kwargs.get("grad_clip", 10.0)
+        self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
+        self.validate_interval = kwargs.get("validate_interval", 5000)
         
     
         try:
@@ -100,13 +97,22 @@
             logging.warning("distributed is not initialized, only single shard")
         self.rank = rank
         self.world_size = world_size
-        
-        os.makedirs(os.path.join(self.output_dir, "tensorboard"), exist_ok=True)
-        self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None
-
-        
-    
-    def _save_checkpoint(self, epoch, step=None):
+        self.train_acc_avg = 0.0
+        self.train_loss_avg = 0.0
+        self.val_acc_avg = 0.0
+        self.val_loss_avg = 0.0
+        self.best_acc_idx = 0
+        self.saved_ckpts = {}
+        self.val_acc_list = []
+        self.step_or_epoch = -1
+       
+    def save_checkpoint(self, epoch,
+                        step=None,
+                        model=None,
+                        optim=None,
+                        scheduler=None,
+                        scaler=None,
+                        ):
         """
         Saves a checkpoint containing the model's state, the optimizer's state,
         and the scheduler's state at the end of the given epoch. This method is
@@ -115,29 +121,65 @@
         Args:
             epoch (int): The epoch number at which the checkpoint is being saved.
         """
-        state = {
-            'epoch': epoch,
-            'state_dict': self.model.state_dict(),
-            'optimizer': self.optim.state_dict(),
-            'scheduler': self.scheduler.state_dict(),
-        }
-        if self.scaler:
-            state["scaler_state"] = self.scaler.state_dict()
-        # Create output directory if it does not exist
-        os.makedirs(self.output_dir, exist_ok=True)
-        if step is None:
-            filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
-        else:
-            filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')
         
-        torch.save(state, filename)
-        
-        print(f'\nCheckpoint saved to {filename}\n')
-        latest = Path(os.path.join(self.output_dir, f'model.pt'))
-        torch.save(state, latest)
+        if self.rank == 0:
+            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
+            self.step_or_epoch += 1
+            state = {
+                'epoch': epoch,
+                'state_dict': model.state_dict(),
+                'optimizer': optim.state_dict(),
+                'scheduler': scheduler.state_dict(),
+                "acc": self.val_acc_list,
+                "step_or_epoch": self.step_or_epoch,
+            }
+            if hasattr(model, "module"):
+                state["state_dict"] = model.module.state_dict()
+                
+            if scaler:
+                state["scaler_state"] = scaler.state_dict()
+            # Create output directory if it does not exist
+            os.makedirs(self.output_dir, exist_ok=True)
+            if step is None:
+                ckpt_name = f'model.pt.ep{epoch}'
+            else:
+                ckpt_name = f'model.pt.ep{epoch}.{step}'
+            filename = os.path.join(self.output_dir, ckpt_name)
+            torch.save(state, filename)
+            
+            logging.info(f'\nCheckpoint saved to {filename}\n')
+            latest = Path(os.path.join(self.output_dir, f'model.pt'))
+            torch.save(state, latest)
+            
+            if self.val_acc_list[self.step_or_epoch] >= self.val_acc_list[self.best_acc_idx]:
+                self.best_acc_idx = self.step_or_epoch
+                best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
+                torch.save(state, best_ckpt)
+                logging.info(f"Update best acc: {self.val_acc_list[self.best_acc_idx]}, {best_ckpt}")
+            else:
+                logging.info(f"No improvement in acc: {self.val_acc_list[self.best_acc_idx]}")
+            
+            if self.keep_nbest_models > 0:
+                self.saved_ckpts[ckpt_name] = self.val_acc_list[-1]
+                if len(self.saved_ckpts) > self.keep_nbest_models:
 
+                    min_key = min(self.saved_ckpts, key=self.saved_ckpts.get)
+                    if min_key in self.saved_ckpts:
+                        del self.saved_ckpts[min_key]
+                    filename = os.path.join(self.output_dir, min_key)
+                    logging.info(f"Delete: {filename}")
+                    if os.path.exists(filename):
+                        os.remove(filename)
+                
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
     
-    def _resume_checkpoint(self, resume_path):
+    def resume_checkpoint(self,
+                          model=None,
+                          optim=None,
+                          scheduler=None,
+                          scaler=None,
+                          ):
         """
         Resumes training from a checkpoint at the given file path.
         Loads the model's state, the optimizer's state, and the scheduler's state.
@@ -145,114 +187,79 @@
         Args:
             resume_path (str): The file path to the checkpoint to resume from.
         """
-        ckpt = os.path.join(resume_path, "model.pt")
-        if os.path.isfile(ckpt):
-            checkpoint = torch.load(ckpt, map_location="cpu")
-            self.start_epoch = checkpoint['epoch'] + 1
-            # self.model.load_state_dict(checkpoint['state_dict'])
-            src_state = checkpoint['state_dict']
-            dst_state = self.model.state_dict()
-            for k in dst_state.keys():
-                if not k.startswith("module.") and "module."+k in src_state.keys():
-                    k_ddp = "module."+k
-                else:
-                    k_ddp = k
-                if k_ddp in src_state.keys():
-                    dst_state[k] = src_state[k_ddp]
-                else:
-                    print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
-
-            self.model.load_state_dict(dst_state)
-            self.optim.load_state_dict(checkpoint['optimizer'])
-            self.scheduler.load_state_dict(checkpoint['scheduler'])
-            if self.scaler and 'scaler_state' in checkpoint:
-                self.scaler.load_state_dict(checkpoint['scaler_state'])
-            print(f"Checkpoint loaded successfully from '{ckpt}'")
-        else:
-            print(f"No checkpoint found at '{ckpt}', does not resume status!")
-        
-        self.model.to(self.device)
-        if self.use_ddp or self.use_fsdp:
-            dist.barrier()
-        
-    def run(self):
-        """
-        Starts the training process, iterating over epochs, training the model,
-        and saving checkpoints at the end of each epoch.
-        """
         if self.resume:
-            self._resume_checkpoint(self.output_dir)
-        
-        for epoch in range(self.start_epoch, self.max_epoch + 1):
-            time1 = time.perf_counter()
-            self._train_epoch(epoch)
-
-
-            
-            if self.use_ddp or self.use_fsdp:
-                dist.barrier()
+            ckpt = os.path.join(self.output_dir, "model.pt")
+            if os.path.isfile(ckpt):
+                checkpoint = torch.load(ckpt)
+                self.start_epoch = checkpoint['epoch'] + 1
+                # self.model.load_state_dict(checkpoint['state_dict'])
+                src_state = checkpoint['state_dict']
+                dst_state = model.state_dict()
+                for k in dst_state.keys():
+                    if not k.startswith("module.") and "module."+k in src_state.keys():
+                        k_ddp = "module."+k
+                    else:
+                        k_ddp = k
+                    if k_ddp in src_state.keys():
+                        dst_state[k] = src_state[k_ddp]
+                    else:
+                        print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
+    
+                model.load_state_dict(dst_state)
+                optim.load_state_dict(checkpoint['optimizer'])
+                scheduler.load_state_dict(checkpoint['scheduler'])
+                if scaler is not None and 'scaler_state' in checkpoint:
+                    scaler.load_state_dict(checkpoint['scaler_state'])
                 
-            self._validate_epoch(epoch)
-
-            if self.use_ddp or self.use_fsdp:
-                dist.barrier()
-           
-           
-            if self.rank == 0:
-                self._save_checkpoint(epoch)
-            
-            if self.use_ddp or self.use_fsdp:
-                dist.barrier()
-            
-            self.scheduler.step()
-
-            time2 = time.perf_counter()
-            time_escaped = (time2 - time1)/3600.0
-            print(f"\nrank: {self.local_rank}, time_escaped_epoch: {time_escaped:.3f} hours, estimated to finish {self.max_epoch} epoch: {(self.max_epoch-epoch)*time_escaped:.3f} hours\n")
-
-        if self.rank == 0:
-            average_checkpoints(self.output_dir, self.avg_nbest_model)
-            
+                self.val_acc_list = checkpoint["acc"]
+                self.step_or_epoch = checkpoint["step_or_epoch"]
+                
+                print(f"Checkpoint loaded successfully from '{ckpt}'")
+            else:
+                print(f"No checkpoint found at '{ckpt}', does not resume status!")
+    
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-
-
-        if self.writer:
-            self.writer.close()
         
-    
-    def _train_epoch(self, epoch):
+ 
+    def train_epoch(self,
+                model=None,
+                optim=None,
+                scheduler=None,
+                scaler=None,
+                dataloader_train=None,
+                dataloader_val=None,
+                epoch=None,
+                writer=None,
+                    ):
         """
         Defines the training process for a single epoch with gradient accumulation.
         Args:
             epoch (int): The current epoch number.
         """
-        self.model.train()
-        pbar = tqdm(colour="blue", desc=f"rank: {self.local_rank}, Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
-                    dynamic_ncols=True)
-        
+        logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
+        model.train()
+
         # Set the number of steps for gradient accumulation
-        accum_grad = self.kwargs.get("accum_grad", 1)
+        accum_grad = self.accum_grad
         # Initialize the gradient accumulation
-        self.optim.zero_grad()
+        optim.zero_grad()
         speed_stats = {}
         time5 = time.perf_counter()
         
-        for batch_idx, batch in enumerate(self.dataloader_train):
+        for batch_idx, batch in enumerate(dataloader_train):
             self.batch_total += 1
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1-time5:0.3f}"
 
             batch = to_device(batch, self.device)
             
-            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
+            my_context = model.no_sync if batch_idx % accum_grad != 0 else nullcontext
             with my_context():
                 time2 = time.perf_counter()
                 with maybe_autocast(self.use_fp16):
-                    retval = self.model(**batch)
+                    retval = model(**batch)
                     
-                if self.disable_gpu_cache: torch.cuda.empty_cache()
-
                 time3 = time.perf_counter()
                 speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                 loss, stats, weight = retval
@@ -261,95 +268,105 @@
                     # Apply weighted averaging for loss and stats
                     loss = (loss * weight.type(loss.dtype)).sum()
                     # if distributed, this method can also apply all_reduce()
-                    stats, weight = recursive_average(stats, weight, distributed=True)
+                    # stats, weight = recursive_average(stats, weight, distributed=True)
+                    if self.use_ddp or self.use_fsdp:
+                        dist.all_reduce(weight, op=dist.ReduceOp.SUM)
                     # Now weight is summation over all workers
-                    loss /= weight
+                    loss /= weight.sum() # shape:[1] -> shape:[]
                     # Multiply world_size because DistributedDataParallel
                     # automatically normalizes the gradient by world_size.
                     loss *= self.world_size
                 # Scale the loss since we're not updating for every mini-batch
                 loss = loss / accum_grad
                 if self.use_fp16:
-                    self.scaler.scale(loss).backward()
+                    scaler.scale(loss).backward()
                 else:
                     loss.backward()
                 time4 = time.perf_counter()
                 speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
+                
+                self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
+                if "acc" in stats:
+                    self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
+                if self.use_ddp or self.use_fsdp:
+                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
+                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
+                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+                
             
             # Perform an optimizer step only after accumulating enough gradients
-            if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
+            if (batch_idx + 1) % accum_grad == 0:
                 # Perform gradient clipping if it is set
-                if self.kwargs.get("grad_clip", None) is not None:
+                if self.grad_clip > 0:
                     grad_norm = torch.nn.utils.clip_grad_norm_(
-                        self.model.parameters(),
-                        max_norm=self.kwargs.get("grad_clip", 10.0),
-                        norm_type=self.kwargs.get("grad_clip_type", 2.0),
+                        model.parameters(),
+                        max_norm=self.grad_clip,
+                        norm_type=self.grad_clip_type,
                     )
                     if not torch.isfinite(grad_norm):
                         logging.warning(
                             f"The grad norm is {grad_norm}. Skipping updating the model."
                         )
-                        self.optim.zero_grad()  # Reset gradients
+                        optim.zero_grad()  # Reset gradients
                         continue
                 
                 # Execute an optimization step (update model parameters)
                 if self.use_ddp or self.use_fsdp:
                     dist.barrier()
                 if self.use_fp16:
-                    self.scaler.step(self.optim)
-                    self.scaler.update()
+                    scaler.step(optim)
+                    scaler.update()
                 else:
-                    self.optim.step()
-                self.scheduler.step()
+                    optim.step()
+                scheduler.step()
                 # Clear gradients for the next accumulation stage
-                self.optim.zero_grad(set_to_none=True)
+                optim.zero_grad(set_to_none=True)
                 total_time = f"{time.perf_counter() - time5:0.3f}"
                 time5 = time.perf_counter()
                 speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
     
                 speed_stats["total_time"] = total_time
+                lr = scheduler.get_last_lr()[0]
+                batch_num_epoch = -1
+                if hasattr(dataloader_train, "__len__"):
+                    batch_num_epoch = len(dataloader_train)
+                self.log(epoch, batch_idx,
+                         batch_num_epoch=batch_num_epoch,
+                         lr=lr,
+                         loss=loss.detach().cpu().item(),
+                         speed_stats=speed_stats,
+                         stats=stats,
+                         writer=writer,
+                         tag="train",
+                         )
 
-
-            
-            if (batch_idx+1) % self.log_interval == 0 or (batch_idx+1) == len(self.dataloader_train):
-                pbar.update(self.log_interval)
-                gpu_info = "GPU, memory: {:.3f} GB, " \
-                           "{:.3f} GB, "\
-                           "{:.3f} GB, "\
-                           "{:.3f} GB".format(torch.cuda.memory_allocated()/1024/1024/1024,
-                                             torch.cuda.max_memory_allocated()/1024/1024/1024,
-                                             torch.cuda.memory_reserved()/1024/1024/1024,
-                                             torch.cuda.max_memory_reserved()/1024/1024/1024,
-                                             )
-                lr = self.scheduler.get_last_lr()[0]
-                time_now = datetime.now()
-                time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
-                description = (
-                    f"{time_now}, "
-                    f"rank: {self.local_rank}, "
-                    f"epoch: {epoch}/{self.max_epoch}, "
-                    f"step: {batch_idx+1}/{len(self.dataloader_train)}, total step: {self.batch_total}, "
-                    f"(loss: {loss.detach().cpu().item():.3f}), "
-                    f"(lr: {lr:.3e}), "
-                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
-                    f"{speed_stats}, "
-                    f"{gpu_info}"
+            if (batch_idx + 1) % self.validate_interval == 0:
+                self.validate_epoch(
+                    model=model,
+                    dataloader_val=dataloader_val,
+                    epoch=epoch,
+                    writer=writer
                 )
-                pbar.set_description(description)
-                if self.writer:
-                    self.writer.add_scalar(f'rank{self.local_rank}_Loss/train', loss.item(), self.batch_total)
-                    self.writer.add_scalar(f'rank{self.local_rank}_lr/train', lr, self.batch_total)
-                    for key, var in stats.items():
-                        self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', var.item(), self.batch_total)
-                    for key, var in speed_stats.items():
-                        self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', eval(var), self.batch_total)
 
-            if (batch_idx+1) % self.save_checkpoint_interval == 0 and self.rank == 0:
-                self._save_checkpoint(epoch, step=batch_idx+1)
-        pbar.close()
+            if (batch_idx+1) % self.save_checkpoint_interval == 0:
+                self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
+
+        
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+        
         
 
-    def _validate_epoch(self, epoch):
+    def validate_epoch(self,
+                       model=None,
+                       dataloader_val=None,
+                       epoch=None,
+                       writer=None,
+                       **kwargs,
+                       ):
         """
         Defines the validation process for a single epoch.
         Should be implemented with the actual model validation steps.
@@ -357,18 +374,19 @@
         Args:
             epoch (int): The current epoch number.
         """
-        self.model.eval()
+        logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
+        model.eval()
+        
         with torch.no_grad():
-            pbar = tqdm(colour="red", desc=f"rank: {self.local_rank}, Validation Epoch: {epoch + 1}", total=len(self.dataloader_val),
-                        dynamic_ncols=True)
+            
             speed_stats = {}
             time5 = time.perf_counter()
-            for batch_idx, batch in enumerate(self.dataloader_val):
+            for batch_idx, batch in enumerate(dataloader_val):
                 time1 = time.perf_counter()
                 speed_stats["data_load"] = f"{time1 - time5:0.3f}"
                 batch = to_device(batch, self.device)
                 time2 = time.perf_counter()
-                retval = self.model(**batch)
+                retval = model(**batch)
                 time3 = time.perf_counter()
                 speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                 loss, stats, weight = retval
@@ -378,8 +396,10 @@
                     loss = (loss * weight.type(loss.dtype)).sum()
                     # if distributed, this method can also apply all_reduce()
                     stats, weight = recursive_average(stats, weight, distributed=True)
+                    if self.use_ddp or self.use_fsdp:
+                        dist.all_reduce(weight, op=dist.ReduceOp.SUM)
                     # Now weight is summation over all workers
-                    loss /= weight
+                    loss /= weight.sum() # shape:[1] -> shape:[]
                     # Multiply world_size because DistributedDataParallel
                     # automatically normalizes the gradient by world_size.
                     loss *= self.world_size
@@ -387,29 +407,94 @@
                 loss = loss
                 time4 = time.perf_counter()
 
+                self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
+                if "acc" in stats:
+                    self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
+                if self.use_ddp or self.use_fsdp:
+                    val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
+                    val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
+                    dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+                    self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+                    self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
                 
-                if (batch_idx+1) % self.log_interval == 0 or (batch_idx+1) == len(self.dataloader_val):
-                    pbar.update(self.log_interval)
-                    time_now = datetime.now()
-                    time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
-                    description = (
-                        f"{time_now}, "
-                        f"rank: {self.local_rank}, "
-                        f"validation epoch: {epoch}/{self.max_epoch}, "
-                        f"step: {batch_idx+1}/{len(self.dataloader_val)}, "
-                        f"(loss: {loss.detach().cpu().item():.3f}), "
-                        f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
-                        f"{speed_stats}, "
-                    )
-                    pbar.set_description(description)
-                    if self.writer:
-                        self.writer.add_scalar(f"rank{self.local_rank}_Loss/val", loss.item(),
-                                               epoch*len(self.dataloader_val) + batch_idx)
-                        for key, var in stats.items():
-                            self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', var.item(),
-                                                   epoch * len(self.dataloader_val) + batch_idx)
-                        for key, var in speed_stats.items():
-                            self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', eval(var),
-                                                   epoch * len(self.dataloader_val) + batch_idx)
+                batch_num_epoch = -1
+                if hasattr(dataloader_val, "__len__"):
+                    batch_num_epoch = len(dataloader_val)
+                self.log(epoch, batch_idx,
+                         batch_num_epoch=batch_num_epoch,
+                         lr=0.0,
+                         loss=loss.detach().cpu().item(),
+                         speed_stats=speed_stats,
+                         stats=stats,
+                         writer=writer,
+                         tag="val",
+                         )
 
-        self.model.train()
\ No newline at end of file
+        self.val_acc_list.append(self.val_acc_avg)
+        model.train()
+        
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+        
+        
+    def log(self,
+            epoch=0,
+            batch_idx=0,
+            batch_num_epoch=-1,
+            lr=0.0,
+            loss=0.0,
+            speed_stats=None,
+            stats=None,
+            writer=None,
+            tag="train",
+            ):
+        
+        if (batch_idx + 1) % self.log_interval == 0:
+            
+            gpu_info = "GPU, memory: usage: {:.3f} GB, " \
+                       "peak: {:.3f} GB, " \
+                       "cache: {:.3f} GB, " \
+                       "cache_peak: {:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+                                          torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
+                                          torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
+                                          torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
+                                          )
+            
+            loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
+            acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
+            description = (
+                f"{tag}, "
+                f"rank: {self.local_rank}, "
+                f"epoch: {epoch}/{self.max_epoch}, "
+                f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
+                f"(loss_avg_rank: {loss:.3f}), "
+                f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
+                f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}), "
+                f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
+                f"(lr: {lr:.3e}), "
+                f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
+                f"{speed_stats}, "
+                f"{gpu_info}"
+            )
+            logging.info(description)
+            
+            if writer is not None:
+                writer.add_scalar(f'rank{self.local_rank}_loss/{tag}', loss, self.batch_total)
+                writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
+                writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
+                for key, var in stats.items():
+                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
+                for key, var in speed_stats.items():
+                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total)
+        
+    def close(self, writer=None):
+        
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+        
+        if writer is not None:
+            writer.close()
+    
+        if self.use_ddp or self.use_fsdp:
+            torch.distributed.destroy_process_group()
\ No newline at end of file
diff --git a/funasr/train_utils/trainer_llm.py b/funasr/train_utils/trainer_llm.py
deleted file mode 100644
index 5f13b5a..0000000
--- a/funasr/train_utils/trainer_llm.py
+++ /dev/null
@@ -1,502 +0,0 @@
-import math
-import os
-import time
-import torch
-import logging
-from tqdm import tqdm
-from datetime import datetime
-import torch.distributed as dist
-from torch.cuda.amp import autocast, GradScaler
-from contextlib import nullcontext, contextmanager
-from pathlib import Path
-
-from funasr.train_utils.device_funcs import to_device
-from funasr.train_utils.recursive_op import recursive_average
-from funasr.train_utils.average_nbest_models import average_checkpoints
-from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-
-@contextmanager
-def maybe_autocast(enabled):
-    if enabled:
-        with autocast():
-            yield
-    else:
-        yield
-
-class Trainer:
-    """
-    A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
-    and optionally resuming from a saved checkpoint.
-
-    Attributes:
-        max_epoch (int): Maximum number of epochs for training.
-        model (torch.nn.Module): The model to be trained.
-        optim (torch.optim.Optimizer): The optimizer to use for training.
-        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
-        dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
-        dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
-        output_dir (str): Directory where model checkpoints will be saved.
-        resume (str, optional): Path to a checkpoint to resume training from.
-    """
-    
-    def __init__(self,
-                 local_rank,
-                 use_ddp: bool = False,
-                 use_fsdp: bool = False,
-                 use_fp16: bool = False,
-                 output_dir: str="./",
-                 **kwargs):
-        """
-        Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
-
-        Args:
-            model (torch.nn.Module): The model to be trained.
-            optim (torch.optim.Optimizer): The optimizer to use for training.
-            scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
-            dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
-            dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
-            **kwargs: Additional keyword arguments:
-                      max_epoch (int): The maximum number of epochs for training.
-                      output_dir (str): The directory where model checkpoints will be saved. Default is './'.
-                      resume (str, optional): The file path to a checkpoint to resume training from.
-        """
-        
-        self.output_dir = output_dir
-        if not os.path.exists(self.output_dir):
-            os.makedirs(self.output_dir, exist_ok=True)
-        self.resume = kwargs.get('resume', True)
-        self.start_epoch = 0
-        self.max_epoch = kwargs.get('max_epoch', 100)
-        self.local_rank = local_rank
-        self.use_ddp = use_ddp
-        self.use_fsdp = use_fsdp
-        self.device = kwargs.get('device', "cuda")
-        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
-        # self.kwargs = kwargs
-        self.log_interval = kwargs.get("log_interval", 50)
-        self.batch_total = 0
-        self.use_fp16 = use_fp16
-        self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
-        # scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
-        # scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
-        # self.scaler = scaler
-        self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
-        self.keep_nbest_models = kwargs.get("keep_nbest_models", -1)
-        self.accum_grad = kwargs.get("accum_grad", 1)
-        self.grad_clip = kwargs.get("grad_clip", 10.0)
-        self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
-        self.validate_interval = kwargs.get("validate_interval", 5000)
-        
-    
-        try:
-            rank = dist.get_rank()
-            world_size = dist.get_world_size()
-        except:
-            rank = 0
-            world_size = 1
-            logging.warning("distributed is not initialized, only single shard")
-        self.rank = rank
-        self.world_size = world_size
-        self.train_acc_avg = 0.0
-        self.train_loss_avg = 0.0
-        self.val_acc_avg = 0.0
-        self.val_loss_avg = 0.0
-        self.best_acc_idx = 0
-        self.saved_ckpts = {}
-        self.val_acc_list = []
-        self.step_or_epoch = -1
-        
-        
-
-        
-    
-    def save_checkpoint(self, epoch,
-                        step=None,
-                        model=None,
-                        optim=None,
-                        scheduler=None,
-                        scaler=None,
-                        ):
-        """
-        Saves a checkpoint containing the model's state, the optimizer's state,
-        and the scheduler's state at the end of the given epoch. This method is
-        intended to be called at the end of each epoch to save the training progress.
-
-        Args:
-            epoch (int): The epoch number at which the checkpoint is being saved.
-        """
-        
-        if self.rank == 0:
-            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
-            self.step_or_epoch += 1
-            state = {
-                'epoch': epoch,
-                'state_dict': model.state_dict(),
-                'optimizer': optim.state_dict(),
-                'scheduler': scheduler.state_dict(),
-                "acc": self.val_acc_list,
-                "step_or_epoch": self.step_or_epoch,
-            }
-            if hasattr(model, "module"):
-                state["state_dict"] = model.module.state_dict()
-                
-            if scaler:
-                state["scaler_state"] = scaler.state_dict()
-            # Create output directory if it does not exist
-            os.makedirs(self.output_dir, exist_ok=True)
-            if step is None:
-                ckpt_name = f'model.pt.ep{epoch}'
-            else:
-                ckpt_name = f'model.pt.ep{epoch}.{step}'
-            filename = os.path.join(self.output_dir, ckpt_name)
-            torch.save(state, filename)
-            
-            logging.info(f'\nCheckpoint saved to {filename}\n')
-            latest = Path(os.path.join(self.output_dir, f'model.pt'))
-            torch.save(state, latest)
-            
-            if self.val_acc_list[self.step_or_epoch] >= self.val_acc_list[self.best_acc_idx]:
-                self.best_acc_idx = self.step_or_epoch
-                best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
-                torch.save(state, best_ckpt)
-                logging.info(f"Update best acc: {self.val_acc_list[self.best_acc_idx]}, {best_ckpt}")
-            else:
-                logging.info(f"No improvement in acc: {self.val_acc_list[self.best_acc_idx]}")
-            
-            if self.keep_nbest_models > 0:
-                self.saved_ckpts[ckpt_name] = self.val_acc_list[-1]
-                if len(self.saved_ckpts) > self.keep_nbest_models:
-
-                    min_key = min(self.saved_ckpts, key=self.saved_ckpts.get)
-                    if min_key in self.saved_ckpts:
-                        del self.saved_ckpts[min_key]
-                    filename = os.path.join(self.output_dir, min_key)
-                    logging.info(f"Delete: {filename}")
-                    if os.path.exists(filename):
-                        os.remove(filename)
-                
-        if self.use_ddp or self.use_fsdp:
-            dist.barrier()
-    
-    def resume_checkpoint(self,
-                          model=None,
-                          optim=None,
-                          scheduler=None,
-                          scaler=None,
-                          ):
-        """
-        Resumes training from a checkpoint at the given file path.
-        Loads the model's state, the optimizer's state, and the scheduler's state.
-
-        Args:
-            resume_path (str): The file path to the checkpoint to resume from.
-        """
-        if self.resume:
-            ckpt = os.path.join(self.output_dir, "model.pt")
-            if os.path.isfile(ckpt):
-                checkpoint = torch.load(ckpt)
-                self.start_epoch = checkpoint['epoch'] + 1
-                # self.model.load_state_dict(checkpoint['state_dict'])
-                src_state = checkpoint['state_dict']
-                dst_state = model.state_dict()
-                for k in dst_state.keys():
-                    if not k.startswith("module.") and "module."+k in src_state.keys():
-                        k_ddp = "module."+k
-                    else:
-                        k_ddp = k
-                    if k_ddp in src_state.keys():
-                        dst_state[k] = src_state[k_ddp]
-                    else:
-                        print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
-    
-                model.load_state_dict(dst_state)
-                optim.load_state_dict(checkpoint['optimizer'])
-                scheduler.load_state_dict(checkpoint['scheduler'])
-                if scaler is not None and 'scaler_state' in checkpoint:
-                    scaler.load_state_dict(checkpoint['scaler_state'])
-                
-                self.val_acc_list = checkpoint["acc"]
-                self.step_or_epoch = checkpoint["step_or_epoch"]
-                
-                print(f"Checkpoint loaded successfully from '{ckpt}'")
-            else:
-                print(f"No checkpoint found at '{ckpt}', does not resume status!")
-    
-        if self.use_ddp or self.use_fsdp:
-            dist.barrier()
-        
- 
-    def train_epoch(self,
-                model=None,
-                optim=None,
-                scheduler=None,
-                scaler=None,
-                dataloader_train=None,
-                dataloader_val=None,
-                epoch=None,
-                writer=None,
-                    ):
-        """
-        Defines the training process for a single epoch with gradient accumulation.
-        Args:
-            epoch (int): The current epoch number.
-        """
-        logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
-        model.train()
-
-        # Set the number of steps for gradient accumulation
-        accum_grad = self.accum_grad
-        # Initialize the gradient accumulation
-        optim.zero_grad()
-        speed_stats = {}
-        time5 = time.perf_counter()
-        
-        for batch_idx, batch in enumerate(dataloader_train):
-            self.batch_total += 1
-            time1 = time.perf_counter()
-            speed_stats["data_load"] = f"{time1-time5:0.3f}"
-
-            batch = to_device(batch, self.device)
-            
-            my_context = model.no_sync if batch_idx % accum_grad != 0 else nullcontext
-            with my_context():
-                time2 = time.perf_counter()
-                with maybe_autocast(self.use_fp16):
-                    retval = model(**batch)
-                    
-                if self.disable_gpu_cache: torch.cuda.empty_cache()
-
-                time3 = time.perf_counter()
-                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
-                loss, stats, weight = retval
-                stats = {k: v for k, v in stats.items() if v is not None}
-                if self.use_ddp or self.use_fsdp:
-                    # Apply weighted averaging for loss and stats
-                    loss = (loss * weight.type(loss.dtype)).sum()
-                    # if distributed, this method can also apply all_reduce()
-                    stats, weight = recursive_average(stats, weight, distributed=True)
-                    # Now weight is summation over all workers
-                    loss /= weight
-                    # Multiply world_size because DistributedDataParallel
-                    # automatically normalizes the gradient by world_size.
-                    loss *= self.world_size
-                # Scale the loss since we're not updating for every mini-batch
-                loss = loss / accum_grad
-                if self.use_fp16:
-                    scaler.scale(loss).backward()
-                else:
-                    loss.backward()
-                time4 = time.perf_counter()
-                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
-                
-                self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
-                if "acc" in stats:
-                    self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
-                if self.use_ddp or self.use_fsdp:
-                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
-                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
-                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
-                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
-                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
-                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
-                
-            
-            # Perform an optimizer step only after accumulating enough gradients
-            if (batch_idx + 1) % accum_grad == 0:
-                # Perform gradient clipping if it is set
-                if self.grad_clip > 0:
-                    grad_norm = torch.nn.utils.clip_grad_norm_(
-                        model.parameters(),
-                        max_norm=self.grad_clip,
-                        norm_type=self.grad_clip_type,
-                    )
-                    if not torch.isfinite(grad_norm):
-                        logging.warning(
-                            f"The grad norm is {grad_norm}. Skipping updating the model."
-                        )
-                        optim.zero_grad()  # Reset gradients
-                        continue
-                
-                # Execute an optimization step (update model parameters)
-                if self.use_ddp or self.use_fsdp:
-                    dist.barrier()
-                if self.use_fp16:
-                    scaler.step(optim)
-                    scaler.update()
-                else:
-                    optim.step()
-                scheduler.step()
-                # Clear gradients for the next accumulation stage
-                optim.zero_grad(set_to_none=True)
-                total_time = f"{time.perf_counter() - time5:0.3f}"
-                time5 = time.perf_counter()
-                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
-    
-                speed_stats["total_time"] = total_time
-                lr = scheduler.get_last_lr()[0]
-                batch_num_epoch = -1
-                if hasattr(dataloader_train, "__len__"):
-                    batch_num_epoch = len(dataloader_train)
-                self.log(epoch, batch_idx,
-                         batch_num_epoch=batch_num_epoch,
-                         lr=lr,
-                         loss=loss.detach().cpu().item(),
-                         speed_stats=speed_stats,
-                         stats=stats,
-                         writer=writer,
-                         tag="train",
-                         )
-
-            if (batch_idx + 1) % self.validate_interval == 0:
-                self.validate_epoch(
-                    model=model,
-                    dataloader_val=dataloader_val,
-                    epoch=epoch,
-                    writer=writer
-                )
-
-            if (batch_idx+1) % self.save_checkpoint_interval == 0:
-                self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
-
-        
-        if self.use_ddp or self.use_fsdp:
-            dist.barrier()
-        
-        
-
-    def validate_epoch(self,
-                       model=None,
-                       dataloader_val=None,
-                       epoch=None,
-                       writer=None,
-                       **kwargs,
-                       ):
-        """
-        Defines the validation process for a single epoch.
-        Should be implemented with the actual model validation steps.
-    
-        Args:
-            epoch (int): The current epoch number.
-        """
-        logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
-        model.eval()
-        
-        with torch.no_grad():
-            
-            speed_stats = {}
-            time5 = time.perf_counter()
-            for batch_idx, batch in enumerate(dataloader_val):
-                time1 = time.perf_counter()
-                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
-                batch = to_device(batch, self.device)
-                time2 = time.perf_counter()
-                retval = model(**batch)
-                time3 = time.perf_counter()
-                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
-                loss, stats, weight = retval
-                stats = {k: v for k, v in stats.items() if v is not None}
-                if self.use_ddp or self.use_fsdp:
-                    # Apply weighted averaging for loss and stats
-                    loss = (loss * weight.type(loss.dtype)).sum()
-                    # if distributed, this method can also apply all_reduce()
-                    stats, weight = recursive_average(stats, weight, distributed=True)
-                    # Now weight is summation over all workers
-                    loss /= weight
-                    # Multiply world_size because DistributedDataParallel
-                    # automatically normalizes the gradient by world_size.
-                    loss *= self.world_size
-                # Scale the loss since we're not updating for every mini-batch
-                loss = loss
-                time4 = time.perf_counter()
-
-                self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
-                if "acc" in stats:
-                    self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
-                if self.use_ddp or self.use_fsdp:
-                    val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
-                    val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
-                    dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
-                    dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
-                    self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
-                    self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
-                
-                batch_num_epoch = -1
-                if hasattr(dataloader_val, "__len__"):
-                    batch_num_epoch = len(dataloader_val)
-                self.log(epoch, batch_idx,
-                         batch_num_epoch=batch_num_epoch,
-                         lr=0.0,
-                         loss=loss.detach().cpu().item(),
-                         speed_stats=speed_stats,
-                         stats=stats,
-                         writer=writer,
-                         tag="val",
-                         )
-
-        self.val_acc_list.append(self.val_acc_avg)
-        model.train()
-        
-        if self.use_ddp or self.use_fsdp:
-            dist.barrier()
-        
-        
-    def log(self,
-            epoch=0,
-            batch_idx=0,
-            batch_num_epoch=-1,
-            lr=0.0,
-            loss=0.0,
-            speed_stats=None,
-            stats=None,
-            writer=None,
-            tag="train",
-            ):
-        
-        if (batch_idx + 1) % self.log_interval == 0:
-            
-            gpu_info = "GPU, memory: usage: {:.3f} GB, " \
-                       "peak: {:.3f} GB, " \
-                       "cache: {:.3f} GB, " \
-                       "cache_peak: {:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
-                                          torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
-                                          torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
-                                          torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
-                                          )
-            
-            loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
-            acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
-            description = (
-                f"{tag}, "
-                f"rank: {self.local_rank}, "
-                f"epoch: {epoch}/{self.max_epoch}, "
-                f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
-                f"(loss_avg_rank: {loss:.3f}), "
-                f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
-                f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}), "
-                f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
-                f"(lr: {lr:.3e}), "
-                f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
-                f"{speed_stats}, "
-                f"{gpu_info}"
-            )
-            logging.info(description)
-            
-            if writer is not None:
-                writer.add_scalar(f'rank{self.local_rank}_loss/{tag}', loss, self.batch_total)
-                writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
-                writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
-                for key, var in stats.items():
-                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
-                for key, var in speed_stats.items():
-                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total)
-        
-    def close(self, writer=None):
-        
-        if self.use_ddp or self.use_fsdp:
-            dist.barrier()
-        
-        if writer is not None:
-            writer.close()
-    
-        if self.use_ddp or self.use_fsdp:
-            torch.distributed.destroy_process_group()
\ No newline at end of file
diff --git a/funasr/utils/misc.py b/funasr/utils/misc.py
index a08f263..ef18a61 100644
--- a/funasr/utils/misc.py
+++ b/funasr/utils/misc.py
@@ -1,7 +1,9 @@
+import os
 import io
+import shutil
 from collections import OrderedDict
 import numpy as np
-
+from omegaconf import DictConfig, OmegaConf
 
 def statistic_model_parameters(model, prefix=None):
     var_dict = model.state_dict()
@@ -52,4 +54,21 @@
         if isinstance(value, dict) and key in original:
             deep_update(original[key], value)
         else:
-            original[key] = value
\ No newline at end of file
+            original[key] = value
+            
+            
+def prepare_model_dir(**kwargs):
+    
+
+    os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
+    
+    yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
+    OmegaConf.save(config=kwargs, f=yaml_file)
+    print(kwargs)
+    logging.info("config.yaml is saved to: %s", yaml_file)
+
+    # model_path = kwargs.get("model_path")
+    # if model_path is not None:
+    #     config_json = os.path.join(model_path, "configuration.json")
+    #     if os.path.exists(config_json):
+    #         shutil.copy(config_json, os.path.join(kwargs.get("output_dir", "./"), "configuration.json"))
diff --git a/funasr/utils/prepare_data.py b/funasr/utils/prepare_data.py
deleted file mode 100644
index 36eebdc..0000000
--- a/funasr/utils/prepare_data.py
+++ /dev/null
@@ -1,242 +0,0 @@
-import logging
-import os
-import shutil
-from multiprocessing import Pool
-
-import kaldiio
-import numpy as np
-import librosa
-import torch.distributed as dist
-import torchaudio
-
-
-def filter_wav_text(data_dir, dataset):
-    wav_file = os.path.join(data_dir, dataset, "wav.scp")
-    text_file = os.path.join(data_dir, dataset, "text")
-    with open(wav_file) as f_wav, open(text_file) as f_text:
-        wav_lines = f_wav.readlines()
-        text_lines = f_text.readlines()
-    os.rename(wav_file, "{}.bak".format(wav_file))
-    os.rename(text_file, "{}.bak".format(text_file))
-    wav_dict = {}
-    for line in wav_lines:
-        parts = line.strip().split()
-        if len(parts) < 2:
-            continue
-        wav_dict[parts[0]] = parts[1]
-    text_dict = {}
-    for line in text_lines:
-        parts = line.strip().split()
-        if len(parts) < 2:
-            continue
-        text_dict[parts[0]] = " ".join(parts[1:])
-    filter_count = 0
-    with open(wav_file, "w") as f_wav, open(text_file, "w") as f_text:
-        for sample_name, wav_path in wav_dict.items():
-            if sample_name in text_dict.keys():
-                f_wav.write(sample_name + " " + wav_path + "\n")
-                f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
-            else:
-                filter_count += 1
-    logging.info("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".
-                 format(filter_count, len(wav_lines), dataset))
-
-
-def wav2num_frame(wav_path, frontend_conf):
-    try:
-        waveform, sampling_rate = torchaudio.load(wav_path)
-    except:
-        waveform, sampling_rate = librosa.load(wav_path)
-        waveform = np.expand_dims(waveform, axis=0)
-    n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
-    feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
-    return n_frames, feature_dim
-
-
-def calc_shape_core(root_path, args, idx):
-    file_name = args.data_file_names.split(",")[0]
-    data_name = args.dataset_conf.get("data_names", "speech,text").split(",")[0]
-    scp_file = os.path.join(root_path, "{}.{}".format(file_name, idx))
-    shape_file = os.path.join(root_path, "{}_shape.{}".format(data_name, idx))
-    with open(scp_file) as f:
-        lines = f.readlines()
-    data_type = args.dataset_conf.get("data_types", "sound,text").split(",")[0]
-    if data_type == "sound":
-        frontend_conf = args.frontend_conf
-        dataset_conf = args.dataset_conf
-        length_min = dataset_conf.speech_length_min if hasattr(dataset_conf, "{}_length_min".format(data_name)) else -1
-        length_max = dataset_conf.speech_length_max if hasattr(dataset_conf, "{}_length_max".format(data_name)) else -1
-        with open(shape_file, "w") as f:
-            for line in lines:
-                sample_name, wav_path = line.strip().split()
-                n_frames, feature_dim = wav2num_frame(wav_path, frontend_conf)
-                write_flag = True
-                if n_frames > 0 and length_min > 0:
-                    write_flag = n_frames >= length_min
-                if n_frames > 0 and length_max > 0:
-                    write_flag = n_frames <= length_max
-                if write_flag:
-                    f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))
-                    f.flush()
-    elif data_type == "kaldi_ark":
-        dataset_conf = args.dataset_conf
-        length_min = dataset_conf.speech_length_min if hasattr(dataset_conf, "{}_length_min".format(data_name)) else -1
-        length_max = dataset_conf.speech_length_max if hasattr(dataset_conf, "{}_length_max".format(data_name)) else -1
-        with open(shape_file, "w") as f:
-            for line in lines:
-                sample_name, feature_path = line.strip().split()
-                feature = kaldiio.load_mat(feature_path)
-                n_frames, feature_dim = feature.shape
-                write_flag = True
-                if n_frames > 0 and length_min > 0:
-                    write_flag = n_frames >= length_min
-                if n_frames > 0 and length_max > 0:
-                    write_flag = n_frames <= length_max
-                if write_flag:
-                    f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))
-                    f.flush()
-    elif data_type == "text":
-        with open(shape_file, "w") as f:
-            for line in lines:
-                sample_name, text = line.strip().split(maxsplit=1)
-                n_tokens = len(text.split())
-                f.write("{} {}\n".format(sample_name, str(int(np.ceil(n_tokens)))))
-                f.flush()
-    else:
-        raise RuntimeError("Unsupported data_type: {}".format(data_type))
-
-
-def calc_shape(args, dataset, nj=64):
-    data_name = args.dataset_conf.get("data_names", "speech,text").split(",")[0]
-    shape_path = os.path.join(args.data_dir, dataset, "{}_shape".format(data_name))
-    if os.path.exists(shape_path):
-        logging.info('Shape file for small dataset already exists.')
-        return
-
-    split_shape_path = os.path.join(args.data_dir, dataset, "{}_shape_files".format(data_name))
-    if os.path.exists(split_shape_path):
-        shutil.rmtree(split_shape_path)
-    os.mkdir(split_shape_path)
-
-    # split
-    file_name = args.data_file_names.split(",")[0]
-    scp_file = os.path.join(args.data_dir, dataset, file_name)
-    with open(scp_file) as f:
-        lines = f.readlines()
-        num_lines = len(lines)
-        num_job_lines = num_lines // nj
-    start = 0
-    for i in range(nj):
-        end = start + num_job_lines
-        file = os.path.join(split_shape_path, "{}.{}".format(file_name, str(i + 1)))
-        with open(file, "w") as f:
-            if i == nj - 1:
-                f.writelines(lines[start:])
-            else:
-                f.writelines(lines[start:end])
-        start = end
-
-    p = Pool(nj)
-    for i in range(nj):
-        p.apply_async(calc_shape_core, args=(split_shape_path, args, str(i + 1)))
-    logging.info("Generating shape files, please wait a few minutes...")
-    p.close()
-    p.join()
-
-    # combine
-    with open(shape_path, "w") as f:
-        for i in range(nj):
-            job_file = os.path.join(split_shape_path, "{}_shape.{}".format(data_name, str(i + 1)))
-            with open(job_file) as job_f:
-                lines = job_f.readlines()
-                f.writelines(lines)
-    logging.info('Generating shape files done.')
-
-
-def generate_data_list(args, data_dir, dataset, nj=64):
-    data_names = args.dataset_conf.get("data_names", "speech,text").split(",")
-    file_names = args.data_file_names.split(",")
-    concat_data_name = "_".join(data_names)
-    list_file = os.path.join(data_dir, dataset, "{}_data.list".format(concat_data_name))
-    if os.path.exists(list_file):
-        logging.info('Data list for large dataset already exists.')
-        return
-    split_path = os.path.join(data_dir, dataset, "split")
-    if os.path.exists(split_path):
-        shutil.rmtree(split_path)
-    os.mkdir(split_path)
-
-    data_lines_list = []
-    for file_name in file_names:
-        with open(os.path.join(data_dir, dataset, file_name)) as f:
-            lines = f.readlines()
-            data_lines_list.append(lines)
-    num_lines = len(data_lines_list[0])
-    num_job_lines = num_lines // nj
-    start = 0
-    for i in range(nj):
-        end = start + num_job_lines
-        split_path_nj = os.path.join(split_path, str(i + 1))
-        os.mkdir(split_path_nj)
-        for file_id, file_name in enumerate(file_names):
-            file = os.path.join(split_path_nj, file_name)
-            with open(file, "w") as f:
-                if i == nj - 1:
-                    f.writelines(data_lines_list[file_id][start:])
-                else:
-                    f.writelines(data_lines_list[file_id][start:end])
-        start = end
-
-    with open(list_file, "w") as f_data:
-        for i in range(nj):
-            path = ""
-            for file_name in file_names:
-                path = path + " " + os.path.join(split_path, str(i + 1), file_name)
-            f_data.write(path + "\n")
-
-
-def prepare_data(args, distributed_option):
-    data_names = args.dataset_conf.get("data_names", "speech,text").split(",")
-    data_types = args.dataset_conf.get("data_types", "sound,text").split(",")
-    file_names = args.data_file_names.split(",")
-    batch_type = args.dataset_conf["batch_conf"]["batch_type"]
-    print("data_names: {}, data_types: {}, file_names: {}".format(data_names, data_types, file_names))
-    assert len(data_names) == len(data_types) == len(file_names)
-    if args.dataset_type == "small":
-        args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "{}_shape".format(data_names[0]))]
-        args.valid_shape_file = [os.path.join(args.data_dir, args.valid_set, "{}_shape".format(data_names[0]))]
-        args.train_data_path_and_name_and_type, args.valid_data_path_and_name_and_type = [], []
-        for file_name, data_name, data_type in zip(file_names, data_names, data_types):
-            args.train_data_path_and_name_and_type.append(
-                ["{}/{}/{}".format(args.data_dir, args.train_set, file_name), data_name, data_type])
-            args.valid_data_path_and_name_and_type.append(
-                ["{}/{}/{}".format(args.data_dir, args.valid_set, file_name), data_name, data_type])
-        if os.path.exists(args.train_shape_file[0]):
-            assert os.path.exists(args.valid_shape_file[0])
-            print('shape file for small dataset already exists.')
-            return
-    else:
-        concat_data_name = "_".join(data_names)
-        args.train_data_file = os.path.join(args.data_dir, args.train_set, "{}_data.list".format(concat_data_name))
-        args.valid_data_file = os.path.join(args.data_dir, args.valid_set, "{}_data.list".format(concat_data_name))
-        if os.path.exists(args.train_data_file):
-            assert os.path.exists(args.valid_data_file)
-            print('data list for large dataset already exists.')
-            return
-
-    distributed = distributed_option.distributed
-    if not distributed or distributed_option.dist_rank == 0:
-        if hasattr(args, "filter_input") and args.filter_input:
-            filter_wav_text(args.data_dir, args.train_set)
-            filter_wav_text(args.data_dir, args.valid_set)
-
-        if args.dataset_type == "small" and batch_type != "unsorted":
-            calc_shape(args, args.train_set)
-            calc_shape(args, args.valid_set)
-
-        if args.dataset_type == "large":
-            generate_data_list(args, args.data_dir, args.train_set)
-            generate_data_list(args, args.data_dir, args.valid_set)
-
-    if distributed:
-        dist.barrier()

--
Gitblit v1.9.1