From 1233c0d3ff9cf7fd6131862e7d0b208d3981f6da Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期一, 15 一月 2024 20:34:47 +0800
Subject: [PATCH] code update

---
 funasr/download/runtime_sdk_download_tool.py            |   76 
 funasr/models/scama/chunk_utilis.py                     |  644 ++++++------
 funasr/models/branchformer/model.py                     |   14 
 funasr/train_utils/trainer.py                           |  444 ++++----
 funasr/models/sanm/model.py                             |   14 
 funasr/download/download_from_hub.py                    |  193 ++--
 funasr/models/conformer/model.py                        |   16 
 funasr/models/scama/utils.py                            |  107 +-
 funasr/utils/load_utils.py                              |  170 +-
 runtime/python/utils/test_cer.py                        |   28 
 funasr/datasets/audio_datasets/index_ds.py              |  108 +-
 funasr/models/e_branchformer/model.py                   |   14 
 funasr/models/uniasr/e2e_uni_asr.py                     |   24 
 funasr/bin/inference.py                                 |    3 
 runtime/python/utils/test_rtf.py                        |   28 
 funasr/utils/vad_utils.py                               |   50 
 examples/industrial_data_pretraining/paraformer/demo.py |    4 
 funasr/bin/train.py                                     |  294 +++---
 funasr/tokenizer/abs_tokenizer.py                       |  178 +-
 funasr/schedulers/__init__.py                           |   24 
 funasr/optimizers/__init__.py                           |   22 
 runtime/python/utils/test_rtf_gpu.py                    |   28 
 funasr/datasets/audio_datasets/samplers.py              |  141 +-
 funasr/datasets/audio_datasets/datasets.py              |  171 +-
 24 files changed, 1,391 insertions(+), 1,404 deletions(-)

diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py
index 20f0f64..6dbe33d 100644
--- a/examples/industrial_data_pretraining/paraformer/demo.py
+++ b/examples/industrial_data_pretraining/paraformer/demo.py
@@ -18,5 +18,5 @@
 fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
 
 for batch_idx, fbank_dict in enumerate(fbanks):
-	res = model(**fbank_dict)
-	print(res)
\ No newline at end of file
+    res = model(**fbank_dict)
+    print(res)
\ No newline at end of file
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index ca8771d..7368d16 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -309,10 +309,7 @@
             if not len(sorted_data):
                 logging.info("decoding, utt: {}, empty speech".format(key))
                 continue
-            
 
-            # if kwargs["device"] == "cpu":
-            #     batch_size = 0
             if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
                 batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
             
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 878eb24..0881cb2 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -1,178 +1,180 @@
-import argparse
-import logging
 import os
 import sys
-from io import BytesIO
-from collections.abc import Sequence
 import torch
 import hydra
+import logging
+import argparse
+from io import BytesIO
+import torch.distributed as dist
+from collections.abc import Sequence
 from omegaconf import DictConfig, OmegaConf
-from funasr.train_utils.set_all_random_seed import set_all_random_seed
-from funasr.models.lora.utils import mark_only_lora_as_trainable
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+from funasr.register import tables
 from funasr.optimizers import optim_classes
+from funasr.train_utils.trainer import Trainer
 from funasr.schedulers import scheduler_classes
-from funasr.train_utils.load_pretrained_model import load_pretrained_model
 from funasr.train_utils.initialize import initialize
+from funasr.download.download_from_hub import download_model
+from funasr.models.lora.utils import mark_only_lora_as_trainable
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
+from funasr.train_utils.load_pretrained_model import load_pretrained_model
 # from funasr.tokenizer.build_tokenizer import build_tokenizer
 # from funasr.tokenizer.token_id_converter import TokenIDConverter
 # from funasr.tokenizer.funtoken import build_tokenizer
-from funasr.train_utils.trainer import Trainer
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from funasr.download.download_from_hub import download_model
-from funasr.register import tables
+
 
 @hydra.main(config_name=None, version_base=None)
 def main_hydra(kwargs: DictConfig):
-	if kwargs.get("debug", False):
-		import pdb; pdb.set_trace()
+    if kwargs.get("debug", False):
+        import pdb; pdb.set_trace()
 
-	assert "model" in kwargs
-	if "model_conf" not in kwargs:
-		logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
-		kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
-	
+    assert "model" in kwargs
+    if "model_conf" not in kwargs:
+        logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
+    
 
-	main(**kwargs)
+    main(**kwargs)
 
 
 def main(**kwargs):
-	# preprocess_config(kwargs)
-	# import pdb; pdb.set_trace()
-	# set random seed
-	tables.print()
-	set_all_random_seed(kwargs.get("seed", 0))
-	torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
-	torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
-	torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
-	
-	local_rank = int(os.environ.get('LOCAL_RANK', 0))
-	# Check if we are using DDP or FSDP
-	use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
-	use_fsdp = kwargs.get("use_fsdp", None)
-	if use_ddp or use_fsdp:
-		dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
-		torch.cuda.set_device(local_rank)
-	
-	# save config.yaml
-	if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
-		os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
-		yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
-		OmegaConf.save(config=kwargs, f=yaml_file)
-		logging.info("config.yaml is saved to: %s", yaml_file)
+    # preprocess_config(kwargs)
+    # import pdb; pdb.set_trace()
+    # set random seed
+    tables.print()
+    set_all_random_seed(kwargs.get("seed", 0))
+    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
+    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
+    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
+    
+    local_rank = int(os.environ.get('LOCAL_RANK', 0))
+    # Check if we are using DDP or FSDP
+    use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
+    use_fsdp = kwargs.get("use_fsdp", None)
+    if use_ddp or use_fsdp:
+        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
+        torch.cuda.set_device(local_rank)
+    
+    # save config.yaml
+    if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
+        os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
+        yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
+        OmegaConf.save(config=kwargs, f=yaml_file)
+        logging.info("config.yaml is saved to: %s", yaml_file)
 
-	tokenizer = kwargs.get("tokenizer", None)
-	if tokenizer is not None:
-		tokenizer_class = tables.tokenizer_classes.get(tokenizer)
-		tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
-		kwargs["tokenizer"] = tokenizer
-	
-	# build frontend if frontend is none None
-	frontend = kwargs.get("frontend", None)
-	if frontend is not None:
-		frontend_class = tables.frontend_classes.get(frontend)
-		frontend = frontend_class(**kwargs["frontend_conf"])
-		kwargs["frontend"] = frontend
-		kwargs["input_size"] = frontend.output_size()
-	
-	# import pdb;
-	# pdb.set_trace()
-	# build model
-	model_class = tables.model_classes.get(kwargs["model"])
-	model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+    tokenizer = kwargs.get("tokenizer", None)
+    if tokenizer is not None:
+        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
+        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
+        kwargs["tokenizer"] = tokenizer
+    
+    # build frontend if frontend is none None
+    frontend = kwargs.get("frontend", None)
+    if frontend is not None:
+        frontend_class = tables.frontend_classes.get(frontend)
+        frontend = frontend_class(**kwargs["frontend_conf"])
+        kwargs["frontend"] = frontend
+        kwargs["input_size"] = frontend.output_size()
+    
+    # import pdb;
+    # pdb.set_trace()
+    # build model
+    model_class = tables.model_classes.get(kwargs["model"])
+    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
 
 
 
-	# init_param
-	init_param = kwargs.get("init_param", None)
-	if init_param is not None:
-		if not isinstance(init_param, (list, tuple)):
-			init_param = (init_param,)
-		logging.info("init_param is not None: %s", init_param)
-		for p in init_param:
-			logging.info(f"Loading pretrained params from {p}")
-			load_pretrained_model(
-				model=model,
-				init_param=p,
-				ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
-				oss_bucket=kwargs.get("oss_bucket", None),
-			)
-	else:
-		initialize(model, kwargs.get("init", "kaiming_normal"))
+    # init_param
+    init_param = kwargs.get("init_param", None)
+    if init_param is not None:
+        if not isinstance(init_param, (list, tuple)):
+            init_param = (init_param,)
+        logging.info("init_param is not None: %s", init_param)
+        for p in init_param:
+            logging.info(f"Loading pretrained params from {p}")
+            load_pretrained_model(
+                model=model,
+                init_param=p,
+                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
+                oss_bucket=kwargs.get("oss_bucket", None),
+            )
+    else:
+        initialize(model, kwargs.get("init", "kaiming_normal"))
 
 
-	# freeze_param
-	freeze_param = kwargs.get("freeze_param", None)
-	if freeze_param is not None:
-		freeze_param = eval(freeze_param)
-		if isinstance(freeze_param, Sequence):
-			freeze_param = (freeze_param,)
-		logging.info("freeze_param is not None: %s", freeze_param)
-		for t in freeze_param:
-			for k, p in model.named_parameters():
-				if k.startswith(t + ".") or k == t:
-					logging.info(f"Setting {k}.requires_grad = False")
-					p.requires_grad = False
-	
+    # freeze_param
+    freeze_param = kwargs.get("freeze_param", None)
+    if freeze_param is not None:
+        freeze_param = eval(freeze_param)
+        if isinstance(freeze_param, Sequence):
+            freeze_param = (freeze_param,)
+        logging.info("freeze_param is not None: %s", freeze_param)
+        for t in freeze_param:
+            for k, p in model.named_parameters():
+                if k.startswith(t + ".") or k == t:
+                    logging.info(f"Setting {k}.requires_grad = False")
+                    p.requires_grad = False
+    
 
-	if use_ddp:
-		model = model.cuda(local_rank)
-		model = DDP(model, device_ids=[local_rank],
-		            find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
-	elif use_fsdp:
-		model = FSDP(model).cuda(local_rank)
-	else:
-		model = model.to(device=kwargs.get("device", "cuda"))
-		
-		
-	# optim
-	optim = kwargs.get("optim", "adam")
-	assert optim in optim_classes
-	optim_class = optim_classes.get(optim)
-	optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
-	
-	# scheduler
-	scheduler = kwargs.get("scheduler", "warmuplr")
-	assert scheduler in scheduler_classes
-	scheduler_class = scheduler_classes.get(scheduler)
-	scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
+    if use_ddp:
+        model = model.cuda(local_rank)
+        model = DDP(model, device_ids=[local_rank],
+                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
+    elif use_fsdp:
+        model = FSDP(model).cuda(local_rank)
+    else:
+        model = model.to(device=kwargs.get("device", "cuda"))
+        
+        
+    # optim
+    optim = kwargs.get("optim", "adam")
+    assert optim in optim_classes
+    optim_class = optim_classes.get(optim)
+    optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
+    
+    # scheduler
+    scheduler = kwargs.get("scheduler", "warmuplr")
+    assert scheduler in scheduler_classes
+    scheduler_class = scheduler_classes.get(scheduler)
+    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
 
-	# import pdb;
-	# pdb.set_trace()
-	# dataset
-	dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
-	dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
+    # import pdb;
+    # pdb.set_trace()
+    # dataset
+    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
+    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
 
-	# dataloader
-	batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
-	batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
-	if batch_sampler is not None:
-		batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
-	dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
-	                                            collate_fn=dataset_tr.collator,
-	                                            batch_sampler=batch_sampler,
-	                                            num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
-	                                            pin_memory=True)
-	
+    # dataloader
+    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
+    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
+    if batch_sampler is not None:
+        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
+                                                collate_fn=dataset_tr.collator,
+                                                batch_sampler=batch_sampler,
+                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
+                                                pin_memory=True)
+    
 
-	trainer = Trainer(
-	    model=model,
-	    optim=optim,
-	    scheduler=scheduler,
-	    dataloader_train=dataloader_tr,
-	    dataloader_val=None,
-		local_rank=local_rank,
-		use_ddp=use_ddp,
-		use_fsdp=use_fsdp,
-		**kwargs.get("train_conf"),
-	)
-	trainer.run()
-	
-	if use_ddp or use_fsdp:
-		torch.distributed.destroy_process_group()
+    trainer = Trainer(
+        model=model,
+        optim=optim,
+        scheduler=scheduler,
+        dataloader_train=dataloader_tr,
+        dataloader_val=None,
+        local_rank=local_rank,
+        use_ddp=use_ddp,
+        use_fsdp=use_fsdp,
+        **kwargs.get("train_conf"),
+    )
+    trainer.run()
+    
+    if use_ddp or use_fsdp:
+        torch.distributed.destroy_process_group()
 
-	
+    
 
 if __name__ == "__main__":
-	main_hydra()
\ No newline at end of file
+    main_hydra()
\ No newline at end of file
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
index 7839ff9..edf127f 100644
--- a/funasr/datasets/audio_datasets/datasets.py
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -1,102 +1,93 @@
 import torch
-import json
-import torch.distributed as dist
-import numpy as np
-import kaldiio
-import librosa
-import torchaudio
-import time
-import logging
 
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 from funasr.register import tables
+from funasr.utils.load_utils import extract_fbank
+
 
 @tables.register("dataset_classes", "AudioDataset")
 class AudioDataset(torch.utils.data.Dataset):
-	"""
-	AudioDataset
-	"""
-	def __init__(self,
-	             path,
-	             index_ds: str = None,
-	             frontend=None,
-	             tokenizer=None,
-	             int_pad_value: int = -1,
-	             float_pad_value: float = 0.0,
-	              **kwargs):
-		super().__init__()
-		index_ds_class = tables.index_ds_classes.get(index_ds)
-		self.index_ds = index_ds_class(path)
-		preprocessor_speech = kwargs.get("preprocessor_speech", None)
-		if preprocessor_speech:
-			preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
-			preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
-		self.preprocessor_speech = preprocessor_speech
-		preprocessor_text = kwargs.get("preprocessor_text", None)
-		if preprocessor_text:
-			preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
-			preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
-		self.preprocessor_text = preprocessor_text
-		
-		self.frontend = frontend
-		self.fs = 16000 if frontend is None else frontend.fs
-		self.data_type = "sound"
-		self.tokenizer = tokenizer
+    """
+    AudioDataset
+    """
+    def __init__(self,
+                 path,
+                 index_ds: str = None,
+                 frontend=None,
+                 tokenizer=None,
+                 int_pad_value: int = -1,
+                 float_pad_value: float = 0.0,
+                  **kwargs):
+        super().__init__()
+        index_ds_class = tables.index_ds_classes.get(index_ds)
+        self.index_ds = index_ds_class(path)
+        preprocessor_speech = kwargs.get("preprocessor_speech", None)
+        if preprocessor_speech:
+            preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
+            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
+        self.preprocessor_speech = preprocessor_speech
+        preprocessor_text = kwargs.get("preprocessor_text", None)
+        if preprocessor_text:
+            preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
+            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+        self.preprocessor_text = preprocessor_text
+        
+        self.frontend = frontend
+        self.fs = 16000 if frontend is None else frontend.fs
+        self.data_type = "sound"
+        self.tokenizer = tokenizer
 
-		self.int_pad_value = int_pad_value
-		self.float_pad_value = float_pad_value
-	
-	def get_source_len(self, index):
-		item = self.index_ds[index]
-		return self.index_ds.get_source_len(item)
-	
-	def get_target_len(self, index):
-		item = self.index_ds[index]
-		return self.index_ds.get_target_len(item)
-	
-	def __len__(self):
-		return len(self.index_ds)
-	
-	def __getitem__(self, index):
-		item = self.index_ds[index]
-		# import pdb;
-		# pdb.set_trace()
-		source = item["source"]
-		data_src = load_audio(source, fs=self.fs)
-		if self.preprocessor_speech:
-			data_src = self.preprocessor_speech(data_src)
-		speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
+        self.int_pad_value = int_pad_value
+        self.float_pad_value = float_pad_value
+    
+    def get_source_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_source_len(item)
+    
+    def get_target_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_target_len(item)
+    
+    def __len__(self):
+        return len(self.index_ds)
+    
+    def __getitem__(self, index):
+        item = self.index_ds[index]
+        # import pdb;
+        # pdb.set_trace()
+        source = item["source"]
+        data_src = load_audio(source, fs=self.fs)
+        if self.preprocessor_speech:
+            data_src = self.preprocessor_speech(data_src)
+        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
 
-		target = item["target"]
-		if self.preprocessor_text:
-			target = self.preprocessor_text(target)
-		ids = self.tokenizer.encode(target)
-		ids_lengths = len(ids)
-		text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
+        target = item["target"]
+        if self.preprocessor_text:
+            target = self.preprocessor_text(target)
+        ids = self.tokenizer.encode(target)
+        ids_lengths = len(ids)
+        text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
 
-		return {"speech": speech[0, :, :],
-		        "speech_lengths": speech_lengths,
-		        "text": text,
-		        "text_lengths": text_lengths,
-		        }
-	
-	
-	def collator(self, samples: list=None):
+        return {"speech": speech[0, :, :],
+                "speech_lengths": speech_lengths,
+                "text": text,
+                "text_lengths": text_lengths,
+                }
+    
+    
+    def collator(self, samples: list=None):
+        outputs = {}
+        for sample in samples:
+            for key in sample.keys():
+                if key not in outputs:
+                    outputs[key] = []
+                outputs[key].append(sample[key])
 
+        for key, data_list in outputs.items():
+            if data_list[0].dtype == torch.int64:
 
-		outputs = {}
-		for sample in samples:
-			for key in sample.keys():
-				if key not in outputs:
-					outputs[key] = []
-				outputs[key].append(sample[key])
-
-		for key, data_list in outputs.items():
-			if data_list[0].dtype == torch.int64:
-
-				pad_value = self.int_pad_value
-			else:
-				pad_value = self.float_pad_value
-			outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
-		return outputs
+                pad_value = self.int_pad_value
+            else:
+                pad_value = self.float_pad_value
+            outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+        return outputs
 
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index 79bb26e..8e5b05c 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -1,64 +1,64 @@
-import torch
 import json
-import torch.distributed as dist
-import time
+import torch
 import logging
+import torch.distributed as dist
 
 from funasr.register import tables
 
+
 @tables.register("index_ds_classes", "IndexDSJsonl")
 class IndexDSJsonl(torch.utils.data.Dataset):
-	
-	def __init__(self, path):
-		super().__init__()
-		
-		contents = []
-		with open(path, encoding='utf-8') as fin:
-			for line in fin:
-				data = json.loads(line.strip())
-				if "text" in data:  # for sft
-					self.contents.append(data['text'])
-				if "source" in data:  # for speech lab pretrain
-					prompt = data["prompt"]
-					source = data["source"]
-					target = data["target"]
-					source_len = data["source_len"]
-					target_len = data["target_len"]
+    
+    def __init__(self, path):
+        super().__init__()
+        
+        contents = []
+        with open(path, encoding='utf-8') as fin:
+            for line in fin:
+                data = json.loads(line.strip())
+                if "text" in data:  # for sft
+                    self.contents.append(data['text'])
+                if "source" in data:  # for speech lab pretrain
+                    prompt = data["prompt"]
+                    source = data["source"]
+                    target = data["target"]
+                    source_len = data["source_len"]
+                    target_len = data["target_len"]
 
-					contents.append({"source": source,
-					                 "prompt": prompt,
-					                 "target": target,
-					                 "source_len": source_len,
-					                 "target_len": target_len,
-					                 }
-					                )
-		
-		self.contents = []
-		total_num = len(contents)
-		try:
-			rank = dist.get_rank()
-			world_size = dist.get_world_size()
-		except:
-			rank = 0
-			world_size = 1
-			logging.warning("distributed is not initialized, only single shard")
-		num_per_rank = total_num // world_size
-		
-		# rank = 0
-		# import ipdb; ipdb.set_trace()
-		self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
-	
-		logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
+                    contents.append({"source": source,
+                                     "prompt": prompt,
+                                     "target": target,
+                                     "source_len": source_len,
+                                     "target_len": target_len,
+                                     }
+                                    )
+        
+        self.contents = []
+        total_num = len(contents)
+        try:
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+        except:
+            rank = 0
+            world_size = 1
+            logging.warning("distributed is not initialized, only single shard")
+        num_per_rank = total_num // world_size
+        
+        # rank = 0
+        # import ipdb; ipdb.set_trace()
+        self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
+    
+        logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
 
-	def __len__(self):
-		return len(self.contents)
-	
-	def __getitem__(self, index):
-		return self.contents[index]
-	
-	def get_source_len(self, data_dict):
-		return data_dict["source_len"]
+    def __len__(self):
+        return len(self.contents)
+    
+    def __getitem__(self, index):
+        return self.contents[index]
+    
+    def get_source_len(self, data_dict):
+        return data_dict["source_len"]
 
-	def get_target_len(self, data_dict):
-		
-		return data_dict["target_len"] if "target_len" in data_dict else 0
+    def get_target_len(self, data_dict):
+        
+        return data_dict["target_len"] if "target_len" in data_dict else 0
diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index 9c87245..bc71b28 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -1,5 +1,4 @@
 import torch
-
 import numpy as np
 
 from funasr.register import tables
@@ -7,74 +6,74 @@
 
 @tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
 class BatchSampler(torch.utils.data.BatchSampler):
-	
-	def __init__(self, dataset,
-	             batch_type: str = "example",
-	             batch_size: int = 100,
-	             buffer_size: int = 30,
-	             drop_last: bool = False,
-	             shuffle: bool = True,
-	             **kwargs):
-		
-		self.drop_last = drop_last
-		self.pre_idx = -1
-		self.dataset = dataset
-		self.total_samples = len(dataset)
-		self.batch_type = batch_type
-		self.batch_size = batch_size
-		self.buffer_size = buffer_size
-		self.max_token_length = kwargs.get("max_token_length", 5000)
-		self.shuffle_idx = np.arange(self.total_samples)
-		self.shuffle = shuffle
-	
-	def __len__(self):
-		return self.total_samples
-	
-	def set_epoch(self, epoch):
-		np.random.seed(epoch)
-	
-	def __iter__(self):
-		
-		if self.shuffle:
-			np.random.shuffle(self.shuffle_idx)
-		
-		batch = []
-		max_token = 0
-		num_sample = 0
-		
-		iter_num = (self.total_samples - 1) // self.buffer_size + 1
-		# print("iter_num: ", iter_num)
-		for iter in range(self.pre_idx + 1, iter_num):
-			datalen_with_index = []
-			for i in range(self.buffer_size):
-				idx = iter * self.buffer_size + i
-				if idx >= self.total_samples:
-					continue
-				
-				idx_map = self.shuffle_idx[idx]
-				# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-				sample_len_cur = self.dataset.get_source_len(idx_map) + \
-				                 self.dataset.get_target_len(idx_map)
-				
-				datalen_with_index.append([idx, sample_len_cur])
-			
-			datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
-			for item in datalen_with_index_sort:
-				idx, sample_len_cur_raw = item
-				if sample_len_cur_raw > self.max_token_length:
-					continue
-				
-				max_token_cur = max(max_token, sample_len_cur_raw)
-				max_token_padding = 1 + num_sample
-				if self.batch_type == 'length':
-					max_token_padding *= max_token_cur
-				if max_token_padding <= self.batch_size:
-					batch.append(idx)
-					max_token = max_token_cur
-					num_sample += 1
-				else:
-					yield batch
-					batch = [idx]
-					max_token = sample_len_cur_raw
-					num_sample = 1
+    
+    def __init__(self, dataset,
+                 batch_type: str = "example",
+                 batch_size: int = 100,
+                 buffer_size: int = 30,
+                 drop_last: bool = False,
+                 shuffle: bool = True,
+                 **kwargs):
+        
+        self.drop_last = drop_last
+        self.pre_idx = -1
+        self.dataset = dataset
+        self.total_samples = len(dataset)
+        self.batch_type = batch_type
+        self.batch_size = batch_size
+        self.buffer_size = buffer_size
+        self.max_token_length = kwargs.get("max_token_length", 5000)
+        self.shuffle_idx = np.arange(self.total_samples)
+        self.shuffle = shuffle
+    
+    def __len__(self):
+        return self.total_samples
+    
+    def set_epoch(self, epoch):
+        np.random.seed(epoch)
+    
+    def __iter__(self):
+        
+        if self.shuffle:
+            np.random.shuffle(self.shuffle_idx)
+        
+        batch = []
+        max_token = 0
+        num_sample = 0
+        
+        iter_num = (self.total_samples - 1) // self.buffer_size + 1
+        # print("iter_num: ", iter_num)
+        for iter in range(self.pre_idx + 1, iter_num):
+            datalen_with_index = []
+            for i in range(self.buffer_size):
+                idx = iter * self.buffer_size + i
+                if idx >= self.total_samples:
+                    continue
+                
+                idx_map = self.shuffle_idx[idx]
+                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
+                sample_len_cur = self.dataset.get_source_len(idx_map) + \
+                                 self.dataset.get_target_len(idx_map)
+                
+                datalen_with_index.append([idx, sample_len_cur])
+            
+            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
+            for item in datalen_with_index_sort:
+                idx, sample_len_cur_raw = item
+                if sample_len_cur_raw > self.max_token_length:
+                    continue
+                
+                max_token_cur = max(max_token, sample_len_cur_raw)
+                max_token_padding = 1 + num_sample
+                if self.batch_type == 'length':
+                    max_token_padding *= max_token_cur
+                if max_token_padding <= self.batch_size:
+                    batch.append(idx)
+                    max_token = max_token_cur
+                    num_sample += 1
+                else:
+                    yield batch
+                    batch = [idx]
+                    max_token = sample_len_cur_raw
+                    num_sample = 1
 
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 57e8c41..cde4b7d 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -1,110 +1,111 @@
-import json
 import os
+import json
 from omegaconf import OmegaConf
-import torch
+
 from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf
 
+
 def download_model(**kwargs):
-	model_hub = kwargs.get("model_hub", "ms")
-	if model_hub == "ms":
-		kwargs = download_from_ms(**kwargs)
-	
-	return kwargs
+    model_hub = kwargs.get("model_hub", "ms")
+    if model_hub == "ms":
+        kwargs = download_from_ms(**kwargs)
+    
+    return kwargs
 
 def download_from_ms(**kwargs):
-	model_or_path = kwargs.get("model")
-	if model_or_path in name_maps_ms:
-		model_or_path = name_maps_ms[model_or_path]
-	model_revision = kwargs.get("model_revision")
-	if not os.path.exists(model_or_path):
-		model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"), check_latest=kwargs.get("kwargs", True))
-	kwargs["model_path"] = model_or_path
-	
-	config = os.path.join(model_or_path, "config.yaml")
-	if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
-		
-		config = OmegaConf.load(config)
-		kwargs = OmegaConf.merge(config, kwargs)
-		init_param = os.path.join(model_or_path, "model.pb")
-		kwargs["init_param"] = init_param
-		if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
-			kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
-		if os.path.exists(os.path.join(model_or_path, "tokens.json")):
-			kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
-		if os.path.exists(os.path.join(model_or_path, "seg_dict")):
-			kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
-		if os.path.exists(os.path.join(model_or_path, "bpe.model")):
-			kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
-		kwargs["model"] = config["model"]
-		if os.path.exists(os.path.join(model_or_path, "am.mvn")):
-			kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
-		if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
-			kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
-	elif os.path.exists(os.path.join(model_or_path, "configuration.json")):
-		with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
-			conf_json = json.load(f)
-			cfg = {}
-			add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
-			cfg.update(kwargs)
-			config = OmegaConf.load(cfg["config"])
-			kwargs = OmegaConf.merge(config, cfg)
-		kwargs["model"] = config["model"]
-	return OmegaConf.to_container(kwargs, resolve=True)
+    model_or_path = kwargs.get("model")
+    if model_or_path in name_maps_ms:
+        model_or_path = name_maps_ms[model_or_path]
+    model_revision = kwargs.get("model_revision")
+    if not os.path.exists(model_or_path):
+        model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"), check_latest=kwargs.get("kwargs", True))
+    kwargs["model_path"] = model_or_path
+    
+    config = os.path.join(model_or_path, "config.yaml")
+    if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
+        
+        config = OmegaConf.load(config)
+        kwargs = OmegaConf.merge(config, kwargs)
+        init_param = os.path.join(model_or_path, "model.pb")
+        kwargs["init_param"] = init_param
+        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+        if os.path.exists(os.path.join(model_or_path, "tokens.json")):
+            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
+        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
+            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
+        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
+            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
+        kwargs["model"] = config["model"]
+        if os.path.exists(os.path.join(model_or_path, "am.mvn")):
+            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
+            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
+    elif os.path.exists(os.path.join(model_or_path, "configuration.json")):
+        with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
+            conf_json = json.load(f)
+            cfg = {}
+            add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
+            cfg.update(kwargs)
+            config = OmegaConf.load(cfg["config"])
+            kwargs = OmegaConf.merge(config, cfg)
+        kwargs["model"] = config["model"]
+    return OmegaConf.to_container(kwargs, resolve=True)
 
 def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg = {}):
-	
-	if isinstance(file_path_metas, dict):
-		for k, v in file_path_metas.items():
-			if isinstance(v, str):
-				p = os.path.join(model_or_path, v)
-				if os.path.exists(p):
-					cfg[k] = p
-			elif isinstance(v, dict):
-				if k not in cfg:
-					cfg[k] = {}
-				return add_file_root_path(model_or_path, v, cfg[k])
-	
-	return cfg
+    
+    if isinstance(file_path_metas, dict):
+        for k, v in file_path_metas.items():
+            if isinstance(v, str):
+                p = os.path.join(model_or_path, v)
+                if os.path.exists(p):
+                    cfg[k] = p
+            elif isinstance(v, dict):
+                if k not in cfg:
+                    cfg[k] = {}
+                return add_file_root_path(model_or_path, v, cfg[k])
+    
+    return cfg
 
 
 def get_or_download_model_dir(
-		model,
-		model_revision=None,
-		is_training=False,
-		check_latest=True,
-	):
-	""" Get local model directory or download model if necessary.
+        model,
+        model_revision=None,
+        is_training=False,
+        check_latest=True,
+    ):
+    """ Get local model directory or download model if necessary.
 
-	Args:
-		model (str): model id or path to local model directory.
-		model_revision  (str, optional): model version number.
-		:param is_training:
-	"""
-	from modelscope.hub.check_model import check_local_model_is_latest
-	from modelscope.hub.snapshot_download import snapshot_download
+    Args:
+        model (str): model id or path to local model directory.
+        model_revision  (str, optional): model version number.
+        :param is_training:
+    """
+    from modelscope.hub.check_model import check_local_model_is_latest
+    from modelscope.hub.snapshot_download import snapshot_download
 
-	from modelscope.utils.constant import Invoke, ThirdParty
-	
-	key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
-	
-	if os.path.exists(model) and check_latest:
-		model_cache_dir = model if os.path.isdir(
-			model) else os.path.dirname(model)
-		try:
-			check_local_model_is_latest(
-				model_cache_dir,
-				user_agent={
-					Invoke.KEY: key,
-					ThirdParty.KEY: "funasr"
-				})
-		except:
-			print("could not check the latest version")
-	else:
-		model_cache_dir = snapshot_download(
-			model,
-			revision=model_revision,
-			user_agent={
-				Invoke.KEY: key,
-				ThirdParty.KEY: "funasr"
-			})
-	return model_cache_dir
\ No newline at end of file
+    from modelscope.utils.constant import Invoke, ThirdParty
+    
+    key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
+    
+    if os.path.exists(model) and check_latest:
+        model_cache_dir = model if os.path.isdir(
+            model) else os.path.dirname(model)
+        try:
+            check_local_model_is_latest(
+                model_cache_dir,
+                user_agent={
+                    Invoke.KEY: key,
+                    ThirdParty.KEY: "funasr"
+                })
+        except:
+            print("could not check the latest version")
+    else:
+        model_cache_dir = snapshot_download(
+            model,
+            revision=model_revision,
+            user_agent={
+                Invoke.KEY: key,
+                ThirdParty.KEY: "funasr"
+            })
+    return model_cache_dir
\ No newline at end of file
diff --git a/funasr/download/runtime_sdk_download_tool.py b/funasr/download/runtime_sdk_download_tool.py
index 92416f4..1981347 100644
--- a/funasr/download/runtime_sdk_download_tool.py
+++ b/funasr/download/runtime_sdk_download_tool.py
@@ -1,45 +1,47 @@
-from pathlib import Path
 import os
 import argparse
+from pathlib import Path
+
 from funasr.utils.types import str2bool
 
+
 def main():
-	parser = argparse.ArgumentParser()
-	parser.add_argument('--model-name', type=str, required=True)
-	parser.add_argument('--export-dir', type=str, required=True)
-	parser.add_argument('--export', type=str2bool, default=True, help='whether to export model')
-	parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
-	parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
-	parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
-	parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
-	parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
-	parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
-	parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
-	args = parser.parse_args()
-	
-	model_dir = args.model_name
-	if not Path(args.model_name).exists():
-		from modelscope.hub.snapshot_download import snapshot_download
-		try:
-			model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
-		except:
-			raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format \
-				(model_dir)
-	if args.export:
-		model_file = os.path.join(model_dir, 'model.onnx')
-		if args.quantize:
-			model_file = os.path.join(model_dir, 'model_quant.onnx')
-		if not os.path.exists(model_file):
-			print(".onnx is not exist, begin to export onnx")
-			from funasr.bin.export_model import ModelExport
-			export_model = ModelExport(
-				cache_dir=args.export_dir,
-				onnx=True,
-				device="cpu",
-				quant=args.quantize,
-			)
-			export_model.export(model_dir)
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model-name', type=str, required=True)
+    parser.add_argument('--export-dir', type=str, required=True)
+    parser.add_argument('--export', type=str2bool, default=True, help='whether to export model')
+    parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
+    parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
+    parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
+    parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
+    parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
+    parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
+    parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
+    args = parser.parse_args()
+    
+    model_dir = args.model_name
+    if not Path(args.model_name).exists():
+        from modelscope.hub.snapshot_download import snapshot_download
+        try:
+            model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
+        except:
+            raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format \
+                (model_dir)
+    if args.export:
+        model_file = os.path.join(model_dir, 'model.onnx')
+        if args.quantize:
+            model_file = os.path.join(model_dir, 'model_quant.onnx')
+        if not os.path.exists(model_file):
+            print(".onnx is not exist, begin to export onnx")
+            from funasr.bin.export_model import ModelExport
+            export_model = ModelExport(
+                cache_dir=args.export_dir,
+                onnx=True,
+                device="cpu",
+                quant=args.quantize,
+            )
+            export_model.export(model_dir)
 
 
 if __name__ == "__main__":
-	main()
\ No newline at end of file
+    main()
\ No newline at end of file
diff --git a/funasr/models/branchformer/model.py b/funasr/models/branchformer/model.py
index 53f254d..7fa99b3 100644
--- a/funasr/models/branchformer/model.py
+++ b/funasr/models/branchformer/model.py
@@ -5,12 +5,12 @@
 
 @tables.register("model_classes", "Branchformer")
 class Branchformer(Transformer):
-	"""CTC-attention hybrid Encoder-Decoder model"""
+    """CTC-attention hybrid Encoder-Decoder model"""
 
-	def __init__(
-		self,
-		*args,
-		**kwargs,
-	):
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
 
-		super().__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
diff --git a/funasr/models/conformer/model.py b/funasr/models/conformer/model.py
index 2c26753..171014b 100644
--- a/funasr/models/conformer/model.py
+++ b/funasr/models/conformer/model.py
@@ -7,13 +7,13 @@
 
 @tables.register("model_classes", "Conformer")
 class Conformer(Transformer):
-	"""CTC-attention hybrid Encoder-Decoder model"""
+    """CTC-attention hybrid Encoder-Decoder model"""
 
-	
-	def __init__(
-		self,
-		*args,
-		**kwargs,
-	):
+    
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
 
-		super().__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
diff --git a/funasr/models/e_branchformer/model.py b/funasr/models/e_branchformer/model.py
index 4ffeb3e..14c8c4d 100644
--- a/funasr/models/e_branchformer/model.py
+++ b/funasr/models/e_branchformer/model.py
@@ -5,12 +5,12 @@
 
 @tables.register("model_classes", "EBranchformer")
 class EBranchformer(Transformer):
-	"""CTC-attention hybrid Encoder-Decoder model"""
+    """CTC-attention hybrid Encoder-Decoder model"""
 
-	def __init__(
-		self,
-		*args,
-		**kwargs,
-	):
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
 
-		super().__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
diff --git a/funasr/models/sanm/model.py b/funasr/models/sanm/model.py
index d51478f..4dc8825 100644
--- a/funasr/models/sanm/model.py
+++ b/funasr/models/sanm/model.py
@@ -7,12 +7,12 @@
 
 @tables.register("model_classes", "SANM")
 class SANM(Transformer):
-	"""CTC-attention hybrid Encoder-Decoder model"""
+    """CTC-attention hybrid Encoder-Decoder model"""
 
-	def __init__(
-		self,
-		*args,
-		**kwargs,
-	):
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
 
-		super().__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
diff --git a/funasr/models/scama/chunk_utilis.py b/funasr/models/scama/chunk_utilis.py
index e90ab62..245d282 100644
--- a/funasr/models/scama/chunk_utilis.py
+++ b/funasr/models/scama/chunk_utilis.py
@@ -1,289 +1,287 @@
-
+import math
 import torch
 import numpy as np
-import math
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-import logging
 import torch.nn.functional as F
-from funasr.models.scama.utils import sequence_mask
 
+from funasr.models.scama.utils import sequence_mask
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
 
 
 class overlap_chunk():
-	"""
-	Author: Speech Lab of DAMO Academy, Alibaba Group
-	San-m: Memory equipped self-attention for end-to-end speech recognition
-	https://arxiv.org/abs/2006.01713
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    San-m: Memory equipped self-attention for end-to-end speech recognition
+    https://arxiv.org/abs/2006.01713
 
-	"""
-	def __init__(self,
-		chunk_size: tuple = (16,),
-		stride: tuple = (10,),
-		pad_left: tuple = (0,),
-		encoder_att_look_back_factor: tuple = (1,),
+    """
+    def __init__(self,
+        chunk_size: tuple = (16,),
+        stride: tuple = (10,),
+        pad_left: tuple = (0,),
+        encoder_att_look_back_factor: tuple = (1,),
         shfit_fsmn: int = 0,
         decoder_att_look_back_factor: tuple = (1,),
-	):
+    ):
 
-		pad_left = self.check_chunk_size_args(chunk_size, pad_left)
-		encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor)
-		decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor)
-		self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \
-			= chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor
-		self.shfit_fsmn = shfit_fsmn
-		self.x_add_mask = None
-		self.x_rm_mask = None
-		self.x_len = None
-		self.mask_shfit_chunk = None
-		self.mask_chunk_predictor = None
-		self.mask_att_chunk_encoder = None
-		self.mask_shift_att_chunk_decoder = None
-		self.chunk_outs = None
-		self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
-			= None, None, None, None, None
+        pad_left = self.check_chunk_size_args(chunk_size, pad_left)
+        encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor)
+        decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor)
+        self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \
+            = chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor
+        self.shfit_fsmn = shfit_fsmn
+        self.x_add_mask = None
+        self.x_rm_mask = None
+        self.x_len = None
+        self.mask_shfit_chunk = None
+        self.mask_chunk_predictor = None
+        self.mask_att_chunk_encoder = None
+        self.mask_shift_att_chunk_decoder = None
+        self.chunk_outs = None
+        self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
+            = None, None, None, None, None
 
-	def check_chunk_size_args(self, chunk_size, x):
-		if len(x) < len(chunk_size):
-			x = [x[0] for i in chunk_size]
-		return x
+    def check_chunk_size_args(self, chunk_size, x):
+        if len(x) < len(chunk_size):
+            x = [x[0] for i in chunk_size]
+        return x
 
-	def get_chunk_size(self,
-		ind: int = 0
-	):
-		# with torch.no_grad:
-		chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \
-			self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind]
-		self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \
-			= chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor
-		return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur
+    def get_chunk_size(self,
+        ind: int = 0
+    ):
+        # with torch.no_grad:
+        chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \
+            self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind]
+        self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \
+            = chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor
+        return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur
 
-	def random_choice(self, training=True, decoding_ind=None):
-		chunk_num = len(self.chunk_size)
-		ind = 0
-		if training and chunk_num > 1:
-			ind = torch.randint(0, chunk_num, ()).cpu().item()
-		if not training and decoding_ind is not None:
-			ind = int(decoding_ind)
+    def random_choice(self, training=True, decoding_ind=None):
+        chunk_num = len(self.chunk_size)
+        ind = 0
+        if training and chunk_num > 1:
+            ind = torch.randint(0, chunk_num, ()).cpu().item()
+        if not training and decoding_ind is not None:
+            ind = int(decoding_ind)
 
-		return ind
+        return ind
 
 
 
 
-	def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
+    def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
 
-		with torch.no_grad():
-			x_len = x_len.cpu().numpy()
-			x_len_max = x_len.max()
+        with torch.no_grad():
+            x_len = x_len.cpu().numpy()
+            x_len_max = x_len.max()
 
-			chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind)
-			shfit_fsmn = self.shfit_fsmn
-			pad_right = chunk_size - stride - pad_left
+            chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind)
+            shfit_fsmn = self.shfit_fsmn
+            pad_right = chunk_size - stride - pad_left
 
-			chunk_num_batch = np.ceil(x_len/stride).astype(np.int32)
-			x_len_chunk = (chunk_num_batch-1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch-1) * stride
-			x_len_chunk = x_len_chunk.astype(x_len.dtype)
-			x_len_chunk_max = x_len_chunk.max()
+            chunk_num_batch = np.ceil(x_len/stride).astype(np.int32)
+            x_len_chunk = (chunk_num_batch-1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch-1) * stride
+            x_len_chunk = x_len_chunk.astype(x_len.dtype)
+            x_len_chunk_max = x_len_chunk.max()
 
-			chunk_num = int(math.ceil(x_len_max/stride))
-			dtype = np.int32
-			max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
-			x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
-			x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
-			mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
-			mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
-			mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
-			mask_att_chunk_encoder = np.zeros([0, chunk_num*chunk_size_pad_shift], dtype=dtype)
-			for chunk_ids in range(chunk_num):
-				# x_mask add
-				fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
-				x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
-				x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
-				x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
-				x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
-				x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
-				x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0)
-				x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)
+            chunk_num = int(math.ceil(x_len_max/stride))
+            dtype = np.int32
+            max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
+            x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
+            x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
+            mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
+            mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
+            mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
+            mask_att_chunk_encoder = np.zeros([0, chunk_num*chunk_size_pad_shift], dtype=dtype)
+            for chunk_ids in range(chunk_num):
+                # x_mask add
+                fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
+                x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
+                x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
+                x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
+                x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
+                x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
+                x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0)
+                x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)
 
-				# x_mask rm
-				fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),dtype=dtype)
-				padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left),dtype=dtype)
-				padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
-				x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
-				x_mask_cur_pad_top = np.zeros((chunk_ids*stride, stride), dtype=dtype)
-				x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype)
-				x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0)
-				x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride]
-				x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1)
-				x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)
+                # x_mask rm
+                fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),dtype=dtype)
+                padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left),dtype=dtype)
+                padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
+                x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
+                x_mask_cur_pad_top = np.zeros((chunk_ids*stride, stride), dtype=dtype)
+                x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype)
+                x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0)
+                x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride]
+                x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1)
+                x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)
 
-				# fsmn_padding_mask
-				pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
-				ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
-				mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
-				mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
+                # fsmn_padding_mask
+                pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
+                ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
+                mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
+                mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
 
-				# predictor mask
-				zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
-				ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
-				zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype)
-				ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
-				mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0)
-				mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)
+                # predictor mask
+                zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
+                ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
+                zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype)
+                ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
+                mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0)
+                mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)
 
-				# encoder att mask
-				zeros_1_top = np.zeros([shfit_fsmn, chunk_num*chunk_size_pad_shift], dtype=dtype)
+                # encoder att mask
+                zeros_1_top = np.zeros([shfit_fsmn, chunk_num*chunk_size_pad_shift], dtype=dtype)
 
-				zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
-				zeros_2 = np.zeros([chunk_size, zeros_2_num*chunk_size_pad_shift], dtype=dtype)
+                zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
+                zeros_2 = np.zeros([chunk_size, zeros_2_num*chunk_size_pad_shift], dtype=dtype)
 
-				encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
-				zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
-				ones_2_mid = np.ones([stride, stride], dtype=dtype)
-				zeros_2_bottom = np.zeros([chunk_size-stride, stride], dtype=dtype)
-				zeros_2_right = np.zeros([chunk_size, chunk_size-stride], dtype=dtype)
-				ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
-				ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
-				ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
+                encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
+                zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
+                ones_2_mid = np.ones([stride, stride], dtype=dtype)
+                zeros_2_bottom = np.zeros([chunk_size-stride, stride], dtype=dtype)
+                zeros_2_right = np.zeros([chunk_size, chunk_size-stride], dtype=dtype)
+                ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
+                ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
+                ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
 
-				zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
-				ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
-				ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
+                zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
+                ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
+                ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
 
-				zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
-				zeros_remain = np.zeros([chunk_size, zeros_remain_num*chunk_size_pad_shift], dtype=dtype)
+                zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
+                zeros_remain = np.zeros([chunk_size, zeros_remain_num*chunk_size_pad_shift], dtype=dtype)
 
-				ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1)
-				mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0)
-				mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0)
+                ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1)
+                mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0)
+                mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0)
 
 
-				# decoder fsmn_shift_att_mask
-				zeros_1 = np.zeros([shfit_fsmn, 1])
-				ones_1 = np.ones([chunk_size, 1])
-				mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
-				mask_shift_att_chunk_decoder = np.concatenate(
-					[mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0)
+                # decoder fsmn_shift_att_mask
+                zeros_1 = np.zeros([shfit_fsmn, 1])
+                ones_1 = np.ones([chunk_size, 1])
+                mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
+                mask_shift_att_chunk_decoder = np.concatenate(
+                    [mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0)
 
-			self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max+pad_left]
-			self.x_len_chunk = x_len_chunk
-			self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
-			self.x_len = x_len
-			self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
-			self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
-			self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
-			self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
-			self.chunk_outs = (self.x_add_mask,
-		        self.x_len_chunk,
-		        self.x_rm_mask,
-		        self.x_len,
-		        self.mask_shfit_chunk,
-		        self.mask_chunk_predictor,
-		        self.mask_att_chunk_encoder,
-		        self.mask_shift_att_chunk_decoder)
+            self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max+pad_left]
+            self.x_len_chunk = x_len_chunk
+            self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
+            self.x_len = x_len
+            self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
+            self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
+            self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
+            self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
+            self.chunk_outs = (self.x_add_mask,
+                self.x_len_chunk,
+                self.x_rm_mask,
+                self.x_len,
+                self.mask_shfit_chunk,
+                self.mask_chunk_predictor,
+                self.mask_att_chunk_encoder,
+                self.mask_shift_att_chunk_decoder)
 
-		return self.chunk_outs
+        return self.chunk_outs
 
 
-	def split_chunk(self, x, x_len, chunk_outs):
-		"""
-		:param x: (b, t, d)
-		:param x_length: (b)
-		:param ind: int
-		:return:
-		"""
-		x = x[:, :x_len.max(), :]
-		b, t, d = x.size()
-		x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(
-			x.device)
-		x *= x_len_mask[:, :, None]
+    def split_chunk(self, x, x_len, chunk_outs):
+        """
+        :param x: (b, t, d)
+        :param x_length: (b)
+        :param ind: int
+        :return:
+        """
+        x = x[:, :x_len.max(), :]
+        b, t, d = x.size()
+        x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(
+            x.device)
+        x *= x_len_mask[:, :, None]
 
-		x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
-		x_len_chunk = self.get_x_len_chunk(chunk_outs, x_len.device, dtype=x_len.dtype)
-		pad = (0, 0, self.pad_left_cur, 0)
-		x = F.pad(x, pad, "constant", 0.0)
-		b, t, d = x.size()
-		x = torch.transpose(x, 1, 0)
-		x = torch.reshape(x, [t, -1])
-		x_chunk = torch.mm(x_add_mask, x)
-		x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)
+        x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
+        x_len_chunk = self.get_x_len_chunk(chunk_outs, x_len.device, dtype=x_len.dtype)
+        pad = (0, 0, self.pad_left_cur, 0)
+        x = F.pad(x, pad, "constant", 0.0)
+        b, t, d = x.size()
+        x = torch.transpose(x, 1, 0)
+        x = torch.reshape(x, [t, -1])
+        x_chunk = torch.mm(x_add_mask, x)
+        x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)
 
-		return x_chunk, x_len_chunk
+        return x_chunk, x_len_chunk
 
-	def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
-		x_chunk = x_chunk[:, :x_len_chunk.max(), :]
-		b, t, d = x_chunk.size()
-		x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(
-			x_chunk.device)
-		x_chunk *= x_len_chunk_mask[:, :, None]
+    def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
+        x_chunk = x_chunk[:, :x_len_chunk.max(), :]
+        b, t, d = x_chunk.size()
+        x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(
+            x_chunk.device)
+        x_chunk *= x_len_chunk_mask[:, :, None]
 
-		x_rm_mask = self.get_x_rm_mask(chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
-		x_len = self.get_x_len(chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
-		x_chunk = torch.transpose(x_chunk, 1, 0)
-		x_chunk = torch.reshape(x_chunk, [t, -1])
-		x = torch.mm(x_rm_mask, x_chunk)
-		x = torch.reshape(x, [-1, b, d]).transpose(1, 0)
+        x_rm_mask = self.get_x_rm_mask(chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
+        x_len = self.get_x_len(chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
+        x_chunk = torch.transpose(x_chunk, 1, 0)
+        x_chunk = torch.reshape(x_chunk, [t, -1])
+        x = torch.mm(x_rm_mask, x_chunk)
+        x = torch.reshape(x, [-1, b, d]).transpose(1, 0)
 
-		return x, x_len
+        return x, x_len
 
-	def get_x_add_mask(self, chunk_outs=None, device='cpu', idx=0, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_x_add_mask(self, chunk_outs=None, device='cpu', idx=0, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
-	def get_x_len_chunk(self, chunk_outs=None, device='cpu', idx=1, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_x_len_chunk(self, chunk_outs=None, device='cpu', idx=1, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
 
-	def get_x_rm_mask(self, chunk_outs=None, device='cpu', idx=2, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_x_rm_mask(self, chunk_outs=None, device='cpu', idx=2, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
-	def get_x_len(self, chunk_outs=None, device='cpu', idx=3, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_x_len(self, chunk_outs=None, device='cpu', idx=3, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
 
-	def get_mask_shfit_chunk(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=4, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_mask_shfit_chunk(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=4, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
-	def get_mask_chunk_predictor(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=5, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_mask_chunk_predictor(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=5, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
-	def get_mask_att_chunk_encoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=6, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = np.tile(x[None, :, :, ], [batch_size, 1, 1])
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_mask_att_chunk_encoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=6, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = np.tile(x[None, :, :, ], [batch_size, 1, 1])
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
-	def get_mask_shift_att_chunk_decoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=7, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_mask_shift_att_chunk_decoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=7, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
 
 
 def build_scama_mask_for_cross_attention_decoder(
-							  predictor_alignments: torch.Tensor,
+                              predictor_alignments: torch.Tensor,
                               encoder_sequence_length: torch.Tensor,
                               chunk_size: int = 5,
                               encoder_chunk_size: int = 5,
@@ -291,100 +289,100 @@
                               attention_chunk_size: int = 1,
                               attention_chunk_type: str = 'chunk',
                               step=None,
-							  predictor_mask_chunk_hopping: torch.Tensor = None,
-							  decoder_att_look_back_factor: int = 1,
-							  mask_shift_att_chunk_decoder: torch.Tensor = None,
-							  target_length: torch.Tensor = None,
-							  is_training=True,
+                              predictor_mask_chunk_hopping: torch.Tensor = None,
+                              decoder_att_look_back_factor: int = 1,
+                              mask_shift_att_chunk_decoder: torch.Tensor = None,
+                              target_length: torch.Tensor = None,
+                              is_training=True,
                               dtype: torch.dtype = torch.float32):
-	with torch.no_grad():
-		device = predictor_alignments.device
-		batch_size, chunk_num = predictor_alignments.size()
-		maximum_encoder_length = encoder_sequence_length.max().item()
-		int_type = predictor_alignments.dtype
-		if not is_training:
-			target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
-		maximum_target_length = target_length.max()
-		predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
-		predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)
-	
-	
-		index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
-		index = torch.cumsum(index, dim=1)
-		index = index[:, :, None].repeat(1, 1, chunk_num)
-	
-		index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
-		index_div_bool_zeros = index_div == 0
-		index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
-	
-		index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)
-	
-		index_div_bool_zeros_count *= chunk_size
-		index_div_bool_zeros_count += attention_chunk_center_bias
-		index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length)
-		index_div_bool_zeros_count_ori = index_div_bool_zeros_count
-	
-		index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size
-		max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size
-	
-		mask_flip, mask_flip2 = None, None
-		if attention_chunk_size is not None:
-			index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
-			index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
-			index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
-			mask_flip = 1 - index_div_bool_zeros_count_beg_mask
-			attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1)
-			index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
-	
-			index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
-			index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
-			mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask
-	
-		mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)
-	
-		if predictor_mask_chunk_hopping is not None:
-				b, k, t = mask.size()
-				predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)
-	
-				mask_mask_flip = mask
-				if mask_flip is not None:
-						mask_mask_flip = mask_flip * mask
-	
-				def _fn():
-						mask_sliced = mask[:b, :k, encoder_chunk_size:t]
-						zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
-						mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
-						_, _, tt = predictor_mask_chunk_hopping.size()
-						pad_right_p = max_len_chunk - tt
-						predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
-						masked = mask_sliced * predictor_mask_chunk_hopping_pad
-	
-						mask_true = mask_mask_flip + masked
-						return mask_true
-	
-				mask = _fn() if t > chunk_size else mask_mask_flip
-	
-	
-	
-		if mask_flip2 is not None:
-			mask *= mask_flip2
-	
-		mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
-		mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]
-	
-	
-	
-		mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
-		mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]
-	
-	
-	
-	
-		if attention_chunk_type == 'full':
-			mask = torch.ones_like(mask).to(device)
-		if mask_shift_att_chunk_decoder is not None:
-			mask = mask * mask_shift_att_chunk_decoder
-		mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)
+    with torch.no_grad():
+        device = predictor_alignments.device
+        batch_size, chunk_num = predictor_alignments.size()
+        maximum_encoder_length = encoder_sequence_length.max().item()
+        int_type = predictor_alignments.dtype
+        if not is_training:
+            target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
+        maximum_target_length = target_length.max()
+        predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
+        predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)
+    
+    
+        index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
+        index = torch.cumsum(index, dim=1)
+        index = index[:, :, None].repeat(1, 1, chunk_num)
+    
+        index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
+        index_div_bool_zeros = index_div == 0
+        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
+    
+        index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)
+    
+        index_div_bool_zeros_count *= chunk_size
+        index_div_bool_zeros_count += attention_chunk_center_bias
+        index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length)
+        index_div_bool_zeros_count_ori = index_div_bool_zeros_count
+    
+        index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size
+        max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size
+    
+        mask_flip, mask_flip2 = None, None
+        if attention_chunk_size is not None:
+            index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
+            index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
+            index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
+            mask_flip = 1 - index_div_bool_zeros_count_beg_mask
+            attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1)
+            index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
+    
+            index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
+            index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
+            mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask
+    
+        mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)
+    
+        if predictor_mask_chunk_hopping is not None:
+                b, k, t = mask.size()
+                predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)
+    
+                mask_mask_flip = mask
+                if mask_flip is not None:
+                        mask_mask_flip = mask_flip * mask
+    
+                def _fn():
+                        mask_sliced = mask[:b, :k, encoder_chunk_size:t]
+                        zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
+                        mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
+                        _, _, tt = predictor_mask_chunk_hopping.size()
+                        pad_right_p = max_len_chunk - tt
+                        predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
+                        masked = mask_sliced * predictor_mask_chunk_hopping_pad
+    
+                        mask_true = mask_mask_flip + masked
+                        return mask_true
+    
+                mask = _fn() if t > chunk_size else mask_mask_flip
+    
+    
+    
+        if mask_flip2 is not None:
+            mask *= mask_flip2
+    
+        mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
+        mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]
+    
+    
+    
+        mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
+        mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]
+    
+    
+    
+    
+        if attention_chunk_type == 'full':
+            mask = torch.ones_like(mask).to(device)
+        if mask_shift_att_chunk_decoder is not None:
+            mask = mask * mask_shift_att_chunk_decoder
+        mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)
 
-	return mask
+    return mask
 
diff --git a/funasr/models/scama/utils.py b/funasr/models/scama/utils.py
index 4bb9d4f..8832596 100644
--- a/funasr/models/scama/utils.py
+++ b/funasr/models/scama/utils.py
@@ -1,29 +1,30 @@
 import os
-import torch
-from torch.nn import functional as F
 import yaml
+import torch
 import numpy as np
+from torch.nn import functional as F
+
 
 def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
-	if maxlen is None:
-		maxlen = lengths.max()
-	row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
-	matrix = torch.unsqueeze(lengths, dim=-1)
-	mask = row_vector < matrix
-	mask = mask.detach()
+    if maxlen is None:
+        maxlen = lengths.max()
+    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
+    matrix = torch.unsqueeze(lengths, dim=-1)
+    mask = row_vector < matrix
+    mask = mask.detach()
 
-	return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
+    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
 
 def apply_cmvn(inputs, mvn):
-	device = inputs.device
-	dtype = inputs.dtype
-	frame, dim = inputs.shape
-	meams = np.tile(mvn[0:1, :dim], (frame, 1))
-	vars = np.tile(mvn[1:2, :dim], (frame, 1))
-	inputs -= torch.from_numpy(meams).type(dtype).to(device)
-	inputs *= torch.from_numpy(vars).type(dtype).to(device)
+    device = inputs.device
+    dtype = inputs.dtype
+    frame, dim = inputs.shape
+    meams = np.tile(mvn[0:1, :dim], (frame, 1))
+    vars = np.tile(mvn[1:2, :dim], (frame, 1))
+    inputs -= torch.from_numpy(meams).type(dtype).to(device)
+    inputs *= torch.from_numpy(vars).type(dtype).to(device)
 
-	return inputs.type(torch.float32)
+    return inputs.type(torch.float32)
 
 
 
@@ -36,56 +37,56 @@
 
 
 
-	outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
-	outputs *= stoch_layer_coeff
+    outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True)
+    outputs *= stoch_layer_coeff
 
-	input_dim = inputs.size(-1)
-	output_dim = outputs.size(-1)
+    input_dim = inputs.size(-1)
+    output_dim = outputs.size(-1)
 
-	if input_dim == output_dim:
-		outputs += inputs
-	return outputs
+    if input_dim == output_dim:
+        outputs += inputs
+    return outputs
 
 
 def proc_tf_vocab(vocab_path):
-	with open(vocab_path, encoding="utf-8") as f:
-		token_list = [line.rstrip() for line in f]
-		if '<unk>' not in token_list:
-			token_list.append('<unk>')
-	return token_list
+    with open(vocab_path, encoding="utf-8") as f:
+        token_list = [line.rstrip() for line in f]
+        if '<unk>' not in token_list:
+            token_list.append('<unk>')
+    return token_list
 
 
 def gen_config_for_tfmodel(config_path, vocab_path, output_dir):
-	token_list = proc_tf_vocab(vocab_path)
-	with open(config_path, encoding="utf-8") as f:
-		config = yaml.safe_load(f)
-	
-	config['token_list'] = token_list
-	
-	if not os.path.exists(output_dir):
-		os.makedirs(output_dir)
-	
-	with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
-		yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
+    token_list = proc_tf_vocab(vocab_path)
+    with open(config_path, encoding="utf-8") as f:
+        config = yaml.safe_load(f)
+    
+    config['token_list'] = token_list
+    
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    
+    with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
+        yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
 
 
 class NoAliasSafeDumper(yaml.SafeDumper):
-	# Disable anchor/alias in yaml because looks ugly
-	def ignore_aliases(self, data):
-		return True
+    # Disable anchor/alias in yaml because looks ugly
+    def ignore_aliases(self, data):
+        return True
 
 
 def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
-	"""Safe-dump in yaml with no anchor/alias"""
-	return yaml.dump(
-		data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
-	)
+    """Safe-dump in yaml with no anchor/alias"""
+    return yaml.dump(
+        data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
+    )
 
 
 if __name__ == '__main__':
-	import sys
-	
-	config_path = sys.argv[1]
-	vocab_path = sys.argv[2]
-	output_dir = sys.argv[3]
-	gen_config_for_tfmodel(config_path, vocab_path, output_dir)
\ No newline at end of file
+    import sys
+    
+    config_path = sys.argv[1]
+    vocab_path = sys.argv[2]
+    output_dir = sys.argv[3]
+    gen_config_for_tfmodel(config_path, vocab_path, output_dir)
\ No newline at end of file
diff --git a/funasr/models/uniasr/e2e_uni_asr.py b/funasr/models/uniasr/e2e_uni_asr.py
index de7ed29..390d274 100644
--- a/funasr/models/uniasr/e2e_uni_asr.py
+++ b/funasr/models/uniasr/e2e_uni_asr.py
@@ -541,20 +541,20 @@
                         speech_lengths: (Batch, )
         """
         # with autocast(False):
-        # 	# 1. Extract feats
-        # 	feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        #     # 1. Extract feats
+        #     feats, feats_lengths = self._extract_feats(speech, speech_lengths)
         #
-        # 	# 2. Data augmentation
-        # 	if self.specaug is not None and self.training:
-        # 		feats, feats_lengths = self.specaug(feats, feats_lengths)
+        #     # 2. Data augmentation
+        #     if self.specaug is not None and self.training:
+        #         feats, feats_lengths = self.specaug(feats, feats_lengths)
         #
-        # 	# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
-        # 	if self.normalize is not None:
-        # 		feats, feats_lengths = self.normalize(feats, feats_lengths)
+        #     # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+        #     if self.normalize is not None:
+        #         feats, feats_lengths = self.normalize(feats, feats_lengths)
 
         # Pre-encoder, e.g. used for raw input data
         # if self.preencoder is not None:
-        # 	feats, feats_lengths = self.preencoder(feats, feats_lengths)
+        #     feats, feats_lengths = self.preencoder(feats, feats_lengths)
         encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
             encoder_out,
             encoder_out_lens,
@@ -584,9 +584,9 @@
 
         # # Post-encoder, e.g. NLU
         # if self.postencoder is not None:
-        # 	encoder_out, encoder_out_lens = self.postencoder(
-        # 		encoder_out, encoder_out_lens
-        # 	)
+        #     encoder_out, encoder_out_lens = self.postencoder(
+        #         encoder_out, encoder_out_lens
+        #     )
 
         assert encoder_out.size(0) == speech.size(0), (
             encoder_out.size(),
diff --git a/funasr/optimizers/__init__.py b/funasr/optimizers/__init__.py
index 177f89e..a1a57a5 100644
--- a/funasr/optimizers/__init__.py
+++ b/funasr/optimizers/__init__.py
@@ -3,15 +3,15 @@
 from funasr.optimizers.sgd import SGD
 
 optim_classes = dict(
-	adam=torch.optim.Adam,
-	fairseq_adam=FairseqAdam,
-	adamw=torch.optim.AdamW,
-	sgd=SGD,
-	adadelta=torch.optim.Adadelta,
-	adagrad=torch.optim.Adagrad,
-	adamax=torch.optim.Adamax,
-	asgd=torch.optim.ASGD,
-	lbfgs=torch.optim.LBFGS,
-	rmsprop=torch.optim.RMSprop,
-	rprop=torch.optim.Rprop,
+    adam=torch.optim.Adam,
+    fairseq_adam=FairseqAdam,
+    adamw=torch.optim.AdamW,
+    sgd=SGD,
+    adadelta=torch.optim.Adadelta,
+    adagrad=torch.optim.Adagrad,
+    adamax=torch.optim.Adamax,
+    asgd=torch.optim.ASGD,
+    lbfgs=torch.optim.LBFGS,
+    rmsprop=torch.optim.RMSprop,
+    rprop=torch.optim.Rprop,
 )
\ No newline at end of file
diff --git a/funasr/schedulers/__init__.py b/funasr/schedulers/__init__.py
index cba286a..0d1a578 100644
--- a/funasr/schedulers/__init__.py
+++ b/funasr/schedulers/__init__.py
@@ -8,16 +8,16 @@
 from funasr.schedulers.warmup_lr import WarmupLR
 
 scheduler_classes = dict(
-	ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
-	lambdalr=torch.optim.lr_scheduler.LambdaLR,
-	steplr=torch.optim.lr_scheduler.StepLR,
-	multisteplr=torch.optim.lr_scheduler.MultiStepLR,
-	exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
-	CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
-	noamlr=NoamLR,
-	warmuplr=WarmupLR,
-	tri_stage=TriStageLR,
-	cycliclr=torch.optim.lr_scheduler.CyclicLR,
-	onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
-	CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
+    ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
+    lambdalr=torch.optim.lr_scheduler.LambdaLR,
+    steplr=torch.optim.lr_scheduler.StepLR,
+    multisteplr=torch.optim.lr_scheduler.MultiStepLR,
+    exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
+    CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
+    noamlr=NoamLR,
+    warmuplr=WarmupLR,
+    tri_stage=TriStageLR,
+    cycliclr=torch.optim.lr_scheduler.CyclicLR,
+    onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
+    CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
 )
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index 548bf06..136be13 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -1,100 +1,94 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Iterable
-from typing import List
-from pathlib import Path
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Union
 import json
-
 import numpy as np
+from abc import ABC
+from pathlib import Path
+from abc import abstractmethod
+from typing import Union, Iterable, List, Dict
 
 
 class AbsTokenizer(ABC):
-	@abstractmethod
-	def text2tokens(self, line: str) -> List[str]:
-		raise NotImplementedError
-	
-	@abstractmethod
-	def tokens2text(self, tokens: Iterable[str]) -> str:
-		raise NotImplementedError
+    @abstractmethod
+    def text2tokens(self, line: str) -> List[str]:
+        raise NotImplementedError
+    
+    @abstractmethod
+    def tokens2text(self, tokens: Iterable[str]) -> str:
+        raise NotImplementedError
 
 
 class BaseTokenizer(ABC):
-	def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
-	             unk_symbol: str = "<unk>",
-	             **kwargs,
-	             ):
-		
-		if token_list is not None:
-			if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
-				token_list = Path(token_list)
-				self.token_list_repr = str(token_list)
-				self.token_list: List[str] = []
-				
-				with token_list.open("r", encoding="utf-8") as f:
-					for idx, line in enumerate(f):
-						line = line.rstrip()
-						self.token_list.append(line)
-			elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
-				token_list = Path(token_list)
-				self.token_list_repr = str(token_list)
-				self.token_list: List[str] = []
-				
-				with open(token_list, 'r', encoding='utf-8') as f:
-					self.token_list = json.load(f)
-			
-			
-			else:
-				self.token_list: List[str] = list(token_list)
-				self.token_list_repr = ""
-				for i, t in enumerate(self.token_list):
-					if i == 3:
-						break
-					self.token_list_repr += f"{t}, "
-				self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
-			
-			self.token2id: Dict[str, int] = {}
-			for i, t in enumerate(self.token_list):
-				if t in self.token2id:
-					raise RuntimeError(f'Symbol "{t}" is duplicated')
-				self.token2id[t] = i
-			
-			self.unk_symbol = unk_symbol
-			if self.unk_symbol not in self.token2id:
-				raise RuntimeError(
-					f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
-				)
-			self.unk_id = self.token2id[self.unk_symbol]
-	
-	def encode(self, text):
-		tokens = self.text2tokens(text)
-		text_ints = self.tokens2ids(tokens)
-		
-		return text_ints
-	
-	def decode(self, text_ints):
-		token = self.ids2tokens(text_ints)
-		text = self.tokens2text(token)
-		return text
-	
-	def get_num_vocabulary_size(self) -> int:
-		return len(self.token_list)
-	
-	def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
-		if isinstance(integers, np.ndarray) and integers.ndim != 1:
-			raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
-		return [self.token_list[i] for i in integers]
-	
-	def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
-		return [self.token2id.get(i, self.unk_id) for i in tokens]
-	
-	@abstractmethod
-	def text2tokens(self, line: str) -> List[str]:
-		raise NotImplementedError
-	
-	@abstractmethod
-	def tokens2text(self, tokens: Iterable[str]) -> str:
-		raise NotImplementedError
\ No newline at end of file
+    def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
+                 unk_symbol: str = "<unk>",
+                 **kwargs,
+                 ):
+        
+        if token_list is not None:
+            if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
+                token_list = Path(token_list)
+                self.token_list_repr = str(token_list)
+                self.token_list: List[str] = []
+                
+                with token_list.open("r", encoding="utf-8") as f:
+                    for idx, line in enumerate(f):
+                        line = line.rstrip()
+                        self.token_list.append(line)
+            elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
+                token_list = Path(token_list)
+                self.token_list_repr = str(token_list)
+                self.token_list: List[str] = []
+                
+                with open(token_list, 'r', encoding='utf-8') as f:
+                    self.token_list = json.load(f)
+            
+            
+            else:
+                self.token_list: List[str] = list(token_list)
+                self.token_list_repr = ""
+                for i, t in enumerate(self.token_list):
+                    if i == 3:
+                        break
+                    self.token_list_repr += f"{t}, "
+                self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
+            
+            self.token2id: Dict[str, int] = {}
+            for i, t in enumerate(self.token_list):
+                if t in self.token2id:
+                    raise RuntimeError(f'Symbol "{t}" is duplicated')
+                self.token2id[t] = i
+            
+            self.unk_symbol = unk_symbol
+            if self.unk_symbol not in self.token2id:
+                raise RuntimeError(
+                    f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
+                )
+            self.unk_id = self.token2id[self.unk_symbol]
+    
+    def encode(self, text):
+        tokens = self.text2tokens(text)
+        text_ints = self.tokens2ids(tokens)
+        
+        return text_ints
+    
+    def decode(self, text_ints):
+        token = self.ids2tokens(text_ints)
+        text = self.tokens2text(token)
+        return text
+    
+    def get_num_vocabulary_size(self) -> int:
+        return len(self.token_list)
+    
+    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
+        if isinstance(integers, np.ndarray) and integers.ndim != 1:
+            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
+        return [self.token_list[i] for i in integers]
+    
+    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
+        return [self.token2id.get(i, self.unk_id) for i in tokens]
+    
+    @abstractmethod
+    def text2tokens(self, line: str) -> List[str]:
+        raise NotImplementedError
+    
+    @abstractmethod
+    def tokens2text(self, tokens: Iterable[str]) -> str:
+        raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 59aeaf0..0f0acc2 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -1,233 +1,235 @@
-import torch
 import os
-from funasr.train_utils.device_funcs import to_device
-import logging
 import time
+import torch
+import logging
 from tqdm import tqdm
-from contextlib import nullcontext
 import torch.distributed as dist
+from contextlib import nullcontext
+
+from funasr.train_utils.device_funcs import to_device
 from funasr.train_utils.recursive_op import recursive_average
 
+
 class Trainer:
-	"""
-	A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
-	and optionally resuming from a saved checkpoint.
+    """
+    A simple trainer class for training a PyTorch model, saving checkpoints at the end of each epoch,
+    and optionally resuming from a saved checkpoint.
 
-	Attributes:
-		max_epoch (int): Maximum number of epochs for training.
-		model (torch.nn.Module): The model to be trained.
-		optim (torch.optim.Optimizer): The optimizer to use for training.
-		scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
-		dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
-		dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
-		output_dir (str): Directory where model checkpoints will be saved.
-		resume (str, optional): Path to a checkpoint to resume training from.
-	"""
-	
-	def __init__(self, model,
-	             optim,
-	             scheduler,
-	             dataloader_train,
-	             dataloader_val,
-	             local_rank,
-	             use_ddp=False,
-	             use_fsdp=False,
-	             **kwargs):
-		"""
-		Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
+    Attributes:
+        max_epoch (int): Maximum number of epochs for training.
+        model (torch.nn.Module): The model to be trained.
+        optim (torch.optim.Optimizer): The optimizer to use for training.
+        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
+        dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
+        dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
+        output_dir (str): Directory where model checkpoints will be saved.
+        resume (str, optional): Path to a checkpoint to resume training from.
+    """
+    
+    def __init__(self, model,
+                 optim,
+                 scheduler,
+                 dataloader_train,
+                 dataloader_val,
+                 local_rank,
+                 use_ddp=False,
+                 use_fsdp=False,
+                 **kwargs):
+        """
+        Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
 
-		Args:
-			model (torch.nn.Module): The model to be trained.
-			optim (torch.optim.Optimizer): The optimizer to use for training.
-			scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
-			dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
-			dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
-			**kwargs: Additional keyword arguments:
-					  max_epoch (int): The maximum number of epochs for training.
-					  output_dir (str): The directory where model checkpoints will be saved. Default is './'.
-					  resume (str, optional): The file path to a checkpoint to resume training from.
-		"""
-		
-		self.model = model
-		self.optim = optim
-		self.scheduler = scheduler
-		self.dataloader_train = dataloader_train
-		self.dataloader_val = dataloader_val
-		self.output_dir = kwargs.get('output_dir', './')
-		self.resume = kwargs.get('resume', True)
-		self.start_epoch = 0
-		self.max_epoch = kwargs.get('max_epoch', 100)
-		self.local_rank = local_rank
-		self.use_ddp = use_ddp
-		self.use_fsdp = use_fsdp
-		self.device = next(model.parameters()).device
-		self.kwargs = kwargs
-		
-		if self.resume:
-			self._resume_checkpoint(self.resume)
-	
-		try:
-			rank = dist.get_rank()
-			world_size = dist.get_world_size()
-		except:
-			rank = 0
-			world_size = 1
-			logging.warning("distributed is not initialized, only single shard")
-		self.rank = rank
-		self.world_size = world_size
-	
-	def _save_checkpoint(self, epoch):
-		"""
-		Saves a checkpoint containing the model's state, the optimizer's state,
-		and the scheduler's state at the end of the given epoch. This method is
-		intended to be called at the end of each epoch to save the training progress.
+        Args:
+            model (torch.nn.Module): The model to be trained.
+            optim (torch.optim.Optimizer): The optimizer to use for training.
+            scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
+            dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
+            dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
+            **kwargs: Additional keyword arguments:
+                      max_epoch (int): The maximum number of epochs for training.
+                      output_dir (str): The directory where model checkpoints will be saved. Default is './'.
+                      resume (str, optional): The file path to a checkpoint to resume training from.
+        """
+        
+        self.model = model
+        self.optim = optim
+        self.scheduler = scheduler
+        self.dataloader_train = dataloader_train
+        self.dataloader_val = dataloader_val
+        self.output_dir = kwargs.get('output_dir', './')
+        self.resume = kwargs.get('resume', True)
+        self.start_epoch = 0
+        self.max_epoch = kwargs.get('max_epoch', 100)
+        self.local_rank = local_rank
+        self.use_ddp = use_ddp
+        self.use_fsdp = use_fsdp
+        self.device = next(model.parameters()).device
+        self.kwargs = kwargs
+        
+        if self.resume:
+            self._resume_checkpoint(self.resume)
+    
+        try:
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+        except:
+            rank = 0
+            world_size = 1
+            logging.warning("distributed is not initialized, only single shard")
+        self.rank = rank
+        self.world_size = world_size
+    
+    def _save_checkpoint(self, epoch):
+        """
+        Saves a checkpoint containing the model's state, the optimizer's state,
+        and the scheduler's state at the end of the given epoch. This method is
+        intended to be called at the end of each epoch to save the training progress.
 
-		Args:
-			epoch (int): The epoch number at which the checkpoint is being saved.
-		"""
-		state = {
-			'epoch': epoch,
-			'state_dict': self.model.state_dict(),
-			'optimizer': self.optim.state_dict(),
-			'scheduler': self.scheduler.state_dict(),
-		}
-		# Create output directory if it does not exist
-		os.makedirs(self.output_dir, exist_ok=True)
-		filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
-		torch.save(state, filename)
-		print(f'Checkpoint saved to {filename}')
-	
-	def _resume_checkpoint(self, resume_path):
-		"""
-		Resumes training from a checkpoint at the given file path.
-		Loads the model's state, the optimizer's state, and the scheduler's state.
+        Args:
+            epoch (int): The epoch number at which the checkpoint is being saved.
+        """
+        state = {
+            'epoch': epoch,
+            'state_dict': self.model.state_dict(),
+            'optimizer': self.optim.state_dict(),
+            'scheduler': self.scheduler.state_dict(),
+        }
+        # Create output directory if it does not exist
+        os.makedirs(self.output_dir, exist_ok=True)
+        filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
+        torch.save(state, filename)
+        print(f'Checkpoint saved to {filename}')
+    
+    def _resume_checkpoint(self, resume_path):
+        """
+        Resumes training from a checkpoint at the given file path.
+        Loads the model's state, the optimizer's state, and the scheduler's state.
 
-		Args:
-			resume_path (str): The file path to the checkpoint to resume from.
-		"""
-		if os.path.isfile(resume_path):
-			checkpoint = torch.load(resume_path)
-			self.start_epoch = checkpoint['epoch'] + 1
-			self.model.load_state_dict(checkpoint['state_dict'])
-			self.optim.load_state_dict(checkpoint['optimizer'])
-			self.scheduler.load_state_dict(checkpoint['scheduler'])
-			print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
-		else:
-			print(f"No checkpoint found at '{resume_path}', starting from scratch")
-		
-	def run(self):
-		"""
-		Starts the training process, iterating over epochs, training the model,
-		and saving checkpoints at the end of each epoch.
-		"""
-		for epoch in range(self.start_epoch, self.max_epoch + 1):
-			self._train_epoch(epoch)
-			# self._validate_epoch(epoch)
-			if self.rank == 0:
-				self._save_checkpoint(epoch)
-			self.scheduler.step()
-	
-	def _train_epoch(self, epoch):
-		"""
-		Defines the training process for a single epoch with gradient accumulation.
-		Args:
-			epoch (int): The current epoch number.
-		"""
-		self.model.train()
-		pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
-		            dynamic_ncols=True)
-		
-		# Set the number of steps for gradient accumulation
-		accum_grad = self.kwargs.get("accum_grad", 1)
-		# Initialize the gradient accumulation
-		self.optim.zero_grad()
-		speed_stats = {}
-		time5 = time.perf_counter()
-		for batch_idx, batch in enumerate(self.dataloader_train):
-			time1 = time.perf_counter()
-			speed_stats["data_load"] = f"{time1-time5:0.3f}"
-			# import pdb;
-			# pdb.set_trace()
-			batch = to_device(batch, self.device)
-			
-			my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
-			with my_context():
-				time2 = time.perf_counter()
-				retval = self.model(**batch)
-				time3 = time.perf_counter()
-				speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
-				loss, stats, weight = retval
-				stats = {k: v for k, v in stats.items() if v is not None}
-				if self.use_ddp or self.use_fsdp:
-					# Apply weighted averaging for loss and stats
-					loss = (loss * weight.type(loss.dtype)).sum()
-					# if distributed, this method can also apply all_reduce()
-					stats, weight = recursive_average(stats, weight, distributed=True)
-					# Now weight is summation over all workers
-					loss /= weight
-					# Multiply world_size because DistributedDataParallel
-					# automatically normalizes the gradient by world_size.
-					loss *= self.world_size
-				# Scale the loss since we're not updating for every mini-batch
-				loss = loss / accum_grad
-				loss.backward()
-				time4 = time.perf_counter()
-				speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
-			
-			# Perform an optimizer step only after accumulating enough gradients
-			if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
-				# Perform gradient clipping if it is set
-				if self.kwargs.get("grad_clip", None) is not None:
-					grad_norm = torch.nn.utils.clip_grad_norm_(
-						self.model.parameters(),
-						max_norm=self.kwargs.get("grad_clip", 10.0),
-						norm_type=self.kwargs.get("grad_clip_type", 2.0),
-					)
-					if not torch.isfinite(grad_norm):
-						logging.warning(
-							f"The grad norm is {grad_norm}. Skipping updating the model."
-						)
-						self.optim.zero_grad()  # Reset gradients
-						continue
-				
-				# Execute an optimization step (update model parameters)
-				self.optim.step()
-				self.scheduler.step()
-				# Clear gradients for the next accumulation stage
-				self.optim.zero_grad()
-				total_time = f"{time.perf_counter() - time5:0.3f}"
-				time5 = time.perf_counter()
-				speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
-	
-				speed_stats["total_time"] = total_time
-			
-			# import pdb;
-			# pdb.set_trace()
-			pbar.update(1)
-			if self.local_rank == 0:
-				description = (
-					f"Epoch: {epoch + 1}/{self.max_epoch}, "
-					f"step {batch_idx}/{len(self.dataloader_train)}, "
-					f"{speed_stats}, "
-					f"(loss: {loss.detach().cpu().item():.3f}), "
-					f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
-				)
-				pbar.set_description(description)
-			
-			# if batch_idx == 2:
-			# 	break
-		pbar.close()
+        Args:
+            resume_path (str): The file path to the checkpoint to resume from.
+        """
+        if os.path.isfile(resume_path):
+            checkpoint = torch.load(resume_path)
+            self.start_epoch = checkpoint['epoch'] + 1
+            self.model.load_state_dict(checkpoint['state_dict'])
+            self.optim.load_state_dict(checkpoint['optimizer'])
+            self.scheduler.load_state_dict(checkpoint['scheduler'])
+            print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
+        else:
+            print(f"No checkpoint found at '{resume_path}', starting from scratch")
+        
+    def run(self):
+        """
+        Starts the training process, iterating over epochs, training the model,
+        and saving checkpoints at the end of each epoch.
+        """
+        for epoch in range(self.start_epoch, self.max_epoch + 1):
+            self._train_epoch(epoch)
+            # self._validate_epoch(epoch)
+            if self.rank == 0:
+                self._save_checkpoint(epoch)
+            self.scheduler.step()
+    
+    def _train_epoch(self, epoch):
+        """
+        Defines the training process for a single epoch with gradient accumulation.
+        Args:
+            epoch (int): The current epoch number.
+        """
+        self.model.train()
+        pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
+                    dynamic_ncols=True)
+        
+        # Set the number of steps for gradient accumulation
+        accum_grad = self.kwargs.get("accum_grad", 1)
+        # Initialize the gradient accumulation
+        self.optim.zero_grad()
+        speed_stats = {}
+        time5 = time.perf_counter()
+        for batch_idx, batch in enumerate(self.dataloader_train):
+            time1 = time.perf_counter()
+            speed_stats["data_load"] = f"{time1-time5:0.3f}"
+            # import pdb;
+            # pdb.set_trace()
+            batch = to_device(batch, self.device)
+            
+            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
+            with my_context():
+                time2 = time.perf_counter()
+                retval = self.model(**batch)
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+                loss, stats, weight = retval
+                stats = {k: v for k, v in stats.items() if v is not None}
+                if self.use_ddp or self.use_fsdp:
+                    # Apply weighted averaging for loss and stats
+                    loss = (loss * weight.type(loss.dtype)).sum()
+                    # if distributed, this method can also apply all_reduce()
+                    stats, weight = recursive_average(stats, weight, distributed=True)
+                    # Now weight is summation over all workers
+                    loss /= weight
+                    # Multiply world_size because DistributedDataParallel
+                    # automatically normalizes the gradient by world_size.
+                    loss *= self.world_size
+                # Scale the loss since we're not updating for every mini-batch
+                loss = loss / accum_grad
+                loss.backward()
+                time4 = time.perf_counter()
+                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
+            
+            # Perform an optimizer step only after accumulating enough gradients
+            if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
+                # Perform gradient clipping if it is set
+                if self.kwargs.get("grad_clip", None) is not None:
+                    grad_norm = torch.nn.utils.clip_grad_norm_(
+                        self.model.parameters(),
+                        max_norm=self.kwargs.get("grad_clip", 10.0),
+                        norm_type=self.kwargs.get("grad_clip_type", 2.0),
+                    )
+                    if not torch.isfinite(grad_norm):
+                        logging.warning(
+                            f"The grad norm is {grad_norm}. Skipping updating the model."
+                        )
+                        self.optim.zero_grad()  # Reset gradients
+                        continue
+                
+                # Execute an optimization step (update model parameters)
+                self.optim.step()
+                self.scheduler.step()
+                # Clear gradients for the next accumulation stage
+                self.optim.zero_grad()
+                total_time = f"{time.perf_counter() - time5:0.3f}"
+                time5 = time.perf_counter()
+                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
+    
+                speed_stats["total_time"] = total_time
+            
+            # import pdb;
+            # pdb.set_trace()
+            pbar.update(1)
+            if self.local_rank == 0:
+                description = (
+                    f"Epoch: {epoch + 1}/{self.max_epoch}, "
+                    f"step {batch_idx}/{len(self.dataloader_train)}, "
+                    f"{speed_stats}, "
+                    f"(loss: {loss.detach().cpu().item():.3f}), "
+                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+                )
+                pbar.set_description(description)
+            
+            # if batch_idx == 2:
+            #     break
+        pbar.close()
 
-	def _validate_epoch(self, epoch):
-		"""
-		Defines the validation process for a single epoch.
-		Should be implemented with the actual model validation steps.
-	
-		Args:
-			epoch (int): The current epoch number.
-		"""
-		self.model.eval()
-		with torch.no_grad():
-			for data, target in self.dataloader_val:
-				# Implement the model validation steps here
-				pass
+    def _validate_epoch(self, epoch):
+        """
+        Defines the validation process for a single epoch.
+        Should be implemented with the actual model validation steps.
+    
+        Args:
+            epoch (int): The current epoch number.
+        """
+        self.model.eval()
+        with torch.no_grad():
+            for data, target in self.dataloader_val:
+                # Implement the model validation steps here
+                pass
diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index 4e131a8..9cd3854 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -10,100 +10,100 @@
 import logging
 from torch.nn.utils.rnn import pad_sequence
 try:
-	from funasr.download.file import download_from_url
+    from funasr.download.file import download_from_url
 except:
-	print("urllib is not installed, if you infer from url, please install it first.")
+    print("urllib is not installed, if you infer from url, please install it first.")
 
 
 
 def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
-	if isinstance(data_or_path_or_list, (list, tuple)):
-		if data_type is not None and isinstance(data_type, (list, tuple)):
+    if isinstance(data_or_path_or_list, (list, tuple)):
+        if data_type is not None and isinstance(data_type, (list, tuple)):
 
-			data_types = [data_type] * len(data_or_path_or_list)
-			data_or_path_or_list_ret = [[] for d in data_type]
-			for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)):
-				
-				for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
-					
-					data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
-					data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
+            data_types = [data_type] * len(data_or_path_or_list)
+            data_or_path_or_list_ret = [[] for d in data_type]
+            for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)):
+                
+                for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
+                    
+                    data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
+                    data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
 
-			return data_or_path_or_list_ret
-		else:
-			return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
-	
-	if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
-		data_or_path_or_list = download_from_url(data_or_path_or_list)
-	
-	if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
-		if data_type is None or data_type == "sound":
-			data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
-			data_or_path_or_list = data_or_path_or_list[0, :]
-		elif data_type == "text" and tokenizer is not None:
-			data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
-		elif data_type == "image": # undo
-			pass
-		elif data_type == "video": # undo
-			pass
-		
-		# if data_in is a file or url, set is_final=True
-		if "cache" in kwargs:
-			kwargs["cache"]["is_final"] = True
-	elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
-		data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
-	elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
-		data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze()  # [n_samples,]
-	else:
-		pass
-		# print(f"unsupport data type: {data_or_path_or_list}, return raw data")
-		
-	if audio_fs != fs and data_type != "text":
-		resampler = torchaudio.transforms.Resample(audio_fs, fs)
-		data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
-	return data_or_path_or_list
+            return data_or_path_or_list_ret
+        else:
+            return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
+    
+    if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
+        data_or_path_or_list = download_from_url(data_or_path_or_list)
+    
+    if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
+        if data_type is None or data_type == "sound":
+            data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
+            data_or_path_or_list = data_or_path_or_list[0, :]
+        elif data_type == "text" and tokenizer is not None:
+            data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
+        elif data_type == "image": # undo
+            pass
+        elif data_type == "video": # undo
+            pass
+        
+        # if data_in is a file or url, set is_final=True
+        if "cache" in kwargs:
+            kwargs["cache"]["is_final"] = True
+    elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
+        data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
+    elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
+        data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze()  # [n_samples,]
+    else:
+        pass
+        # print(f"unsupport data type: {data_or_path_or_list}, return raw data")
+        
+    if audio_fs != fs and data_type != "text":
+        resampler = torchaudio.transforms.Resample(audio_fs, fs)
+        data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
+    return data_or_path_or_list
 
 def load_bytes(input):
-	middle_data = np.frombuffer(input, dtype=np.int16)
-	middle_data = np.asarray(middle_data)
-	if middle_data.dtype.kind not in 'iu':
-		raise TypeError("'middle_data' must be an array of integers")
-	dtype = np.dtype('float32')
-	if dtype.kind != 'f':
-		raise TypeError("'dtype' must be a floating point type")
-	
-	i = np.iinfo(middle_data.dtype)
-	abs_max = 2 ** (i.bits - 1)
-	offset = i.min + abs_max
-	array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
-	return array
+    middle_data = np.frombuffer(input, dtype=np.int16)
+    middle_data = np.asarray(middle_data)
+    if middle_data.dtype.kind not in 'iu':
+        raise TypeError("'middle_data' must be an array of integers")
+    dtype = np.dtype('float32')
+    if dtype.kind != 'f':
+        raise TypeError("'dtype' must be a floating point type")
+    
+    i = np.iinfo(middle_data.dtype)
+    abs_max = 2 ** (i.bits - 1)
+    offset = i.min + abs_max
+    array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
+    return array
 
 def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs):
-	# import pdb;
-	# pdb.set_trace()
-	if isinstance(data, np.ndarray):
-		data = torch.from_numpy(data)
-		if len(data.shape) < 2:
-			data = data[None, :] # data: [batch, N]
-		data_len = [data.shape[1]] if data_len is None else data_len
-	elif isinstance(data, torch.Tensor):
-		if len(data.shape) < 2:
-			data = data[None, :] # data: [batch, N]
-		data_len = [data.shape[1]] if data_len is None else data_len
-	elif isinstance(data, (list, tuple)):
-		data_list, data_len = [], []
-		for data_i in data:
-			if isinstance(data_i, np.ndarray):
-				data_i = torch.from_numpy(data_i)
-			data_list.append(data_i)
-			data_len.append(data_i.shape[0])
-		data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
-	# import pdb;
-	# pdb.set_trace()
-	# if data_type == "sound":
-	data, data_len = frontend(data, data_len, **kwargs)
-	
-	if isinstance(data_len, (list, tuple)):
-		data_len = torch.tensor([data_len])
-	return data.to(torch.float32), data_len.to(torch.int32)
+    # import pdb;
+    # pdb.set_trace()
+    if isinstance(data, np.ndarray):
+        data = torch.from_numpy(data)
+        if len(data.shape) < 2:
+            data = data[None, :] # data: [batch, N]
+        data_len = [data.shape[1]] if data_len is None else data_len
+    elif isinstance(data, torch.Tensor):
+        if len(data.shape) < 2:
+            data = data[None, :] # data: [batch, N]
+        data_len = [data.shape[1]] if data_len is None else data_len
+    elif isinstance(data, (list, tuple)):
+        data_list, data_len = [], []
+        for data_i in data:
+            if isinstance(data_i, np.ndarray):
+                data_i = torch.from_numpy(data_i)
+            data_list.append(data_i)
+            data_len.append(data_i.shape[0])
+        data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
+    # import pdb;
+    # pdb.set_trace()
+    # if data_type == "sound":
+    data, data_len = frontend(data, data_len, **kwargs)
+    
+    if isinstance(data_len, (list, tuple)):
+        data_len = torch.tensor([data_len])
+    return data.to(torch.float32), data_len.to(torch.int32)
 
diff --git a/funasr/utils/vad_utils.py b/funasr/utils/vad_utils.py
index f84e2b9..af7c8f2 100644
--- a/funasr/utils/vad_utils.py
+++ b/funasr/utils/vad_utils.py
@@ -1,31 +1,31 @@
 import torch
 from torch.nn.utils.rnn import pad_sequence
 
-def slice_padding_fbank(speech, speech_lengths, vad_segments):
-	speech_list = []
-	speech_lengths_list = []
-	for i, segment in enumerate(vad_segments):
-		
-		bed_idx = int(segment[0][0]*16)
-		end_idx = min(int(segment[0][1]*16), speech_lengths[0])
-		speech_i = speech[0, bed_idx: end_idx]
-		speech_lengths_i = end_idx-bed_idx
-		speech_list.append(speech_i)
-		speech_lengths_list.append(speech_lengths_i)
-	feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
-	speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
-	return feats_pad, speech_lengths_pad
 
+def slice_padding_fbank(speech, speech_lengths, vad_segments):
+    speech_list = []
+    speech_lengths_list = []
+    for i, segment in enumerate(vad_segments):
+        
+        bed_idx = int(segment[0][0]*16)
+        end_idx = min(int(segment[0][1]*16), speech_lengths[0])
+        speech_i = speech[0, bed_idx: end_idx]
+        speech_lengths_i = end_idx-bed_idx
+        speech_list.append(speech_i)
+        speech_lengths_list.append(speech_lengths_i)
+    feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
+    speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
+    return feats_pad, speech_lengths_pad
 
 def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
-	speech_list = []
-	speech_lengths_list = []
-	for i, segment in enumerate(vad_segments):
-		bed_idx = int(segment[0][0] * 16)
-		end_idx = min(int(segment[0][1] * 16), speech_lengths)
-		speech_i = speech[bed_idx: end_idx]
-		speech_lengths_i = end_idx - bed_idx
-		speech_list.append(speech_i)
-		speech_lengths_list.append(speech_lengths_i)
-		
-	return speech_list, speech_lengths_list
\ No newline at end of file
+    speech_list = []
+    speech_lengths_list = []
+    for i, segment in enumerate(vad_segments):
+        bed_idx = int(segment[0][0] * 16)
+        end_idx = min(int(segment[0][1] * 16), speech_lengths)
+        speech_i = speech[bed_idx: end_idx]
+        speech_lengths_i = end_idx - bed_idx
+        speech_list.append(speech_i)
+        speech_lengths_list.append(speech_lengths_i)
+        
+    return speech_list, speech_lengths_list
\ No newline at end of file
diff --git a/runtime/python/utils/test_cer.py b/runtime/python/utils/test_cer.py
index e27e393..d795d33 100644
--- a/runtime/python/utils/test_cer.py
+++ b/runtime/python/utils/test_cer.py
@@ -17,8 +17,8 @@
 
 from funasr.runtime.python.libtorch.funasr_torch import Paraformer
 if args.backend == "onnx":
-	from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
-	
+    from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
+    
 model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
 
 wav_file_f = open(args.wav_file, 'r')
@@ -26,23 +26,23 @@
 
 output_dir = args.output_dir
 if not os.path.exists(output_dir):
-	os.makedirs(output_dir)
+    os.makedirs(output_dir)
 if os.name == 'nt':   # Windows
-	newline = '\r\n'
+    newline = '\r\n'
 else:   # Linux Mac
-	newline = '\n'
+    newline = '\n'
 text_f = open(os.path.join(output_dir, "text"), "w", newline=newline)
 token_f = open(os.path.join(output_dir, "token"), "w", newline=newline)
 
 for i, wav_path_i in enumerate(wav_files):
-	wav_name, wav_path = wav_path_i.strip().split()
-	result = model(wav_path)
-	text_i = "{} {}\n".format(wav_name, result[0]['preds'][0])
-	token_i = "{} {}\n".format(wav_name, result[0]['preds'][1])
-	text_f.write(text_i)
-	text_f.flush()
-	token_f.write(token_i)
-	token_f.flush()
+    wav_name, wav_path = wav_path_i.strip().split()
+    result = model(wav_path)
+    text_i = "{} {}\n".format(wav_name, result[0]['preds'][0])
+    token_i = "{} {}\n".format(wav_name, result[0]['preds'][1])
+    text_f.write(text_i)
+    text_f.flush()
+    token_f.write(token_i)
+    token_f.flush()
 text_f.close()
 token_f.close()
-	
+    
diff --git a/runtime/python/utils/test_rtf.py b/runtime/python/utils/test_rtf.py
index 391a0ac..3fe96a3 100644
--- a/runtime/python/utils/test_rtf.py
+++ b/runtime/python/utils/test_rtf.py
@@ -16,8 +16,8 @@
 
 from funasr.runtime.python.libtorch.funasr_torch import Paraformer
 if args.backend == "onnx":
-	from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
-	
+    from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
+    
 model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
 
 wav_file_f = open(args.wav_file, 'r')
@@ -28,28 +28,28 @@
 num = 30
 wav_path = wav_files[0].split("\t")[1].strip() if "\t" in wav_files[0] else wav_files[0].split(" ")[1].strip()
 for i in range(num):
-	beg_time = time.time()
-	result = model(wav_path)
-	end_time = time.time()
-	duration = end_time-beg_time
-	total += duration
-	print(result)
-	print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))
+    beg_time = time.time()
+    result = model(wav_path)
+    end_time = time.time()
+    duration = end_time-beg_time
+    total += duration
+    print(result)
+    print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))
 
 # infer time
 beg_time = time.time()
 for i, wav_path_i in enumerate(wav_files):
-	wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
-	result = model(wav_path)
+    wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
+    result = model(wav_path)
 end_time = time.time()
 duration = (end_time-beg_time)*1000
 print("total_time_comput_ms: {}".format(int(duration)))
 
 duration_time = 0.0
 for i, wav_path_i in enumerate(wav_files):
-	wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
-	waveform, _ = librosa.load(wav_path, sr=16000)
-	duration_time += len(waveform)/16.0
+    wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
+    waveform, _ = librosa.load(wav_path, sr=16000)
+    duration_time += len(waveform)/16.0
 print("total_time_wav_ms: {}".format(int(duration_time)))
 
 print("total_rtf: {:.5}".format(duration/duration_time))
\ No newline at end of file
diff --git a/runtime/python/utils/test_rtf_gpu.py b/runtime/python/utils/test_rtf_gpu.py
index 84cd2c7..02d5ac6 100644
--- a/runtime/python/utils/test_rtf_gpu.py
+++ b/runtime/python/utils/test_rtf_gpu.py
@@ -17,8 +17,8 @@
 
 from funasr.runtime.python.libtorch.funasr_torch import Paraformer
 if args.backend == "onnx":
-	from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
-	
+    from funasr.runtime.python.onnxruntime.funasr_onnx import Paraformer
+    
 model = Paraformer(args.model_dir, batch_size=args.batch_size, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
 
 wav_file_f = open(args.wav_file, 'r')
@@ -29,20 +29,20 @@
 num = 30
 wav_path = wav_files[0].split("\t")[1].strip() if "\t" in wav_files[0] else wav_files[0].split(" ")[1].strip()
 for i in range(num):
-	beg_time = time.time()
-	result = model(wav_path)
-	end_time = time.time()
-	duration = end_time-beg_time
-	total += duration
-	print(result)
-	print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))
+    beg_time = time.time()
+    result = model(wav_path)
+    end_time = time.time()
+    duration = end_time-beg_time
+    total += duration
+    print(result)
+    print("num: {}, time, {}, avg: {}, rtf: {}".format(len(wav_path), duration, total/(i+1), (total/(i+1))/5.53))
 
 # infer time
 wav_path = []
 beg_time = time.time()
 for i, wav_path_i in enumerate(wav_files):
-	wav_path_i = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
-	wav_path += [wav_path_i]
+    wav_path_i = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
+    wav_path += [wav_path_i]
 result = model(wav_path)
 end_time = time.time()
 duration = (end_time-beg_time)*1000
@@ -50,9 +50,9 @@
 
 duration_time = 0.0
 for i, wav_path_i in enumerate(wav_files):
-	wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
-	waveform, _ = librosa.load(wav_path, sr=16000)
-	duration_time += len(waveform)/16.0
+    wav_path = wav_path_i.split("\t")[1].strip() if "\t" in wav_path_i else wav_path_i.split(" ")[1].strip()
+    waveform, _ = librosa.load(wav_path, sr=16000)
+    duration_time += len(waveform)/16.0
 print("total_time_wav_ms: {}".format(int(duration_time)))
 
 print("total_rtf: {:.5}".format(duration/duration_time))
\ No newline at end of file

--
Gitblit v1.9.1