zhifu gao
2024-03-21 3ac03e448b7673604eb86f619b27521fca55f34d
train & finetune llm-asr (#1519)

* trainer

* trainer

* trainer

* trainer

* trainer

* trainer

* trainer

* trainer

* trainer

* trainer

* trainer
8 files changed, 451 lines modified
examples/industrial_data_pretraining/paraformer_streaming/demo.py |   7
funasr/bin/train_llm.py                                           |  54
funasr/datasets/llm_datasets_vicuna/samplers.py                   | 197
funasr/models/paraformer/model.py                                 |   1
funasr/schedulers/lambdalr_cus.py                                 |  15
funasr/train_utils/average_nbest_models.py                        |   2
funasr/train_utils/trainer.py                                     |   1
funasr/train_utils/trainer_llm.py                                 | 174
examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -8,20 +8,21 @@
chunk_size = [5, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
encoder_chunk_look_back = 0 #number of chunks to lookback for encoder self-attention
decoder_chunk_look_back = 0 #number of encoder chunks to lookback for decoder cross-attention
wav_file="/Users/zhifu/Downloads/NCYzUhAtZNI_0015.wav"
model = AutoModel(model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.4")
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
res = model.generate(input=wav_file,
            chunk_size=chunk_size,
            encoder_chunk_look_back=encoder_chunk_look_back,
            decoder_chunk_look_back=decoder_chunk_look_back,
            )
print(res)
# exit()
import soundfile
import os
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
# wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)
chunk_stride = chunk_size[1] * 960 # 600ms、480ms
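A short worked note on the stride arithmetic above (16 kHz input and 60 ms per encoder chunk are assumptions for illustration):
sample_rate = 16000                                      # assumed 16 kHz input
ms_per_chunk = 60                                        # one streaming encoder chunk covers 60 ms
samples_per_chunk = sample_rate * ms_per_chunk // 1000   # = 960
chunk_stride = chunk_size[1] * samples_per_chunk         # 10 * 960 = 9600 samples = 600 ms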
funasr/bin/train_llm.py
@@ -4,18 +4,21 @@
import os
import sys
import torch
import torch.nn as nn
import hydra
import logging
import time
import argparse
import functools
from io import BytesIO
from contextlib import nullcontext
import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.algorithms.join import Join
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from funasr.train_utils.average_nbest_models import average_checkpoints
@@ -48,7 +51,6 @@
def main(**kwargs):
    print(kwargs)
    
    # set random seed
    set_all_random_seed(kwargs.get("seed", 0))
@@ -61,11 +63,13 @@
        tables.print()
    # Check if we are using DDP or FSDP
    use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", None)
    use_fsdp = kwargs.get("use_fsdp", False)
    # use_ddp = False if use_fsdp else use_fsdp
    if use_ddp or use_fsdp:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
        torch.cuda.set_device(local_rank)
        
    logging.info("Build model, frontend, tokenizer")
    device = kwargs.get("device", "cuda")
    kwargs["device"] = "cpu"
    model = AutoModel(**kwargs)
@@ -76,6 +80,7 @@
        os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
        yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
        OmegaConf.save(config=kwargs, f=yaml_file)
        print(kwargs)
        logging.info("config.yaml is saved to: %s", yaml_file)
    
    # parse kwargs
@@ -105,19 +110,42 @@
        model = DDP(model, device_ids=[local_rank],
                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
    elif use_fsdp:
        model = FSDP(model).cuda(local_rank)
        # model = FSDP(model).cuda(local_rank)
        def custom_auto_wrap_policy(
            module: nn.Module,
            recurse: bool,
            nonwrapped_numel: int,
            # Additional custom arguments
            min_num_params: int = int(1e8),
        ) -> bool:
            # Decide whether to wrap this module based on custom logic:
            # wrap it only when its un-wrapped parameter count is large enough
            is_large = nonwrapped_numel >= min_num_params
            requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
            return is_large and requires_grad_uniform
        # Configure a custom `min_num_params`
        my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
        torch.cuda.set_device(local_rank)
        model = FSDP(model,
                     auto_wrap_policy=my_auto_wrap_policy,
                     mixed_precision=None,
                     device_id=torch.cuda.current_device())
    else:
        model = model.to(device=kwargs.get("device", "cuda"))
    logging.info(f"{model}")
    kwargs["device"] = next(model.parameters()).device
        
    # optim
    logging.info("Build optim")
    optim = kwargs.get("optim", "adam")
    assert optim in optim_classes
    optim_class = optim_classes.get(optim)
    optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
    
    # scheduler
    logging.info("Build scheduler")
    scheduler = kwargs.get("scheduler", "warmuplr")
    assert scheduler in scheduler_classes
    scheduler_class = scheduler_classes.get(scheduler)
@@ -125,6 +153,7 @@
    # dataset
    logging.info("Build dataloader")
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
@@ -142,8 +171,9 @@
    trainer = Trainer(local_rank=local_rank,
                      use_ddp=use_ddp,
                      resume=kwargs.get("resume", True),
                      use_fsdp=use_fsdp,
                      device=kwargs["device"],
                      output_dir=kwargs.get("output_dir", "./exp"),
                      **kwargs.get("train_conf"),
                      )
@@ -160,8 +190,15 @@
    except:
        writer = None
    
    if use_ddp or use_fsdp:
        context = Join([model])
    else:
        context = nullcontext()
    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
        time1 = time.perf_counter()
        with context:
            trainer.train_epoch(
                            model=model,
                            optim=optim,
@@ -172,7 +209,7 @@
                            epoch=epoch,
                            writer=writer
                            )
        scheduler.step()
        trainer.validate_epoch(
            model=model,
            dataloader_val=dataloader_val,
@@ -180,21 +217,20 @@
            writer=writer
        )
        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
        scheduler.step()
        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
        time2 = time.perf_counter()
        time_escaped = (time2 - time1) / 3600.0
        logging.info(
            f"\nrank: {local_rank}, "
            f"rank: {local_rank}, "
            f"time_escaped_epoch: {time_escaped:.3f} hours, "
            f"estimated to finish {trainer.max_epoch} "
            f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n")
    if trainer.rank == 0:
        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)
    trainer.close()
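The hand-written wrap policy above mirrors PyTorch's built-in size-based policy; a minimal sketch of the equivalent FSDP wrapping (assuming model and local_rank come from the surrounding script) would be:
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap every submodule whose un-wrapped parameter count exceeds 1e5,
# matching the intent of custom_auto_wrap_policy / my_auto_wrap_policy above.
wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e5))
torch.cuda.set_device(local_rank)          # local_rank assumed set by torchrun
model = FSDP(model,
             auto_wrap_policy=wrap_policy,
             mixed_precision=None,
             device_id=torch.cuda.current_device())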
funasr/datasets/llm_datasets_vicuna/samplers.py
@@ -232,3 +232,200 @@
    def set_epoch(self, epoch):
        self.epoch = epoch
@tables.register("batch_sampler_classes", "CustomDistributedBufferBatchSampler_fn")
def CustomDistributedBatchSampler_fn(dataset, **kwargs):
    dataloader_args = {}
    dataloader_args["batch_sampler"] = CustomDistributedBufferBatchSampler(dataset, **kwargs)
    dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
    dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
    return dataloader_args
@tables.register("batch_sampler_classes", "CustomDistributedBufferBatchSampler")
class CustomDistributedBatchSampler(Sampler):
    def __init__(self, dataset,
                 batch_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 drop_last=False,
                 is_training: bool = True,
                 sort_size: int = 1024,
                 **kwargs,
                 ):
        try:
            rank = dist.get_rank()
            num_replicas = dist.get_world_size()
        except:
            rank = 0
            num_replicas = 1
        self.rank = rank
        self.num_replicas = num_replicas
        self.dataset = dataset
        self.batch_size = batch_size
        self.is_training = is_training
        self.shuffle = shuffle and is_training
        self.drop_last = drop_last
        # self.total_size = len(dataset)
        if self.drop_last:
            self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (batch_size * num_replicas)
        else:
            self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (batch_size * num_replicas)
        self.num_samples = int(self.total_size // self.num_replicas)
        self.epoch = 0
        self.max_token_length = kwargs.get("max_token_length", None)
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
        self.sort_size = sort_size
    def __iter__(self):
        # Generate a list of indices
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        # Add extra samples to make it evenly divisible
        padding_size = self.total_size - len(indices)
        if padding_size <= len(indices):
            indices += indices[:padding_size]
        else:
            indices += (indices * (padding_size // len(indices)) + indices[:padding_size % len(indices)])
        assert len(indices) == self.total_size
        # Subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples
        # Filter out indices with length greater than the max length, if provided
        if self.max_token_length is not None:
            filtered_indices = []
            for idx in indices:
                source_len = self.dataset.get_source_len(idx) / self.length_scale_source
                if source_len <= self.max_token_length:
                    filtered_indices.append(idx)
            indices = filtered_indices
        # Buffer sorting logic
        sorted_batches = []
        buffer = []
        for idx in indices:
            buffer.append(idx)
            if len(buffer) >= self.sort_size:
                # Sort the buffer based on some criteria, e.g., dataset sample length
                buffer.sort(key=lambda x: self.dataset.get_source_len(x))
                sorted_batches.extend(self._create_batches_from_buffer(buffer))
                buffer = []
        # Handle the remaining items in the buffer
        if buffer:
            buffer.sort(key=lambda x: self.dataset.get_source_len(x))
            sorted_batches.extend(self._create_batches_from_buffer(buffer))
        return iter(sorted_batches)
    def _create_batches_from_buffer(self, buffer):
        # Function to convert the sorted buffer into batches
        batched_buffer = [buffer[i:i + self.batch_size] for i in range(0, len(buffer), self.batch_size)]
        if self.drop_last and len(batched_buffer[-1]) != self.batch_size:
            batched_buffer = batched_buffer[:-1]
        return batched_buffer
    def __len__(self):
        return self.num_samples // self.batch_size
    def set_epoch(self, epoch):
        self.epoch = epoch
@tables.register("batch_sampler_classes", "CustomDistributedDynamicBatchSampler_fn")
def CustomDistributedBatchSampler_fn(dataset, **kwargs):
    dataloader_args = {}
    dataloader_args["batch_sampler"] = CustomDistributedDynamicBatchSampler(dataset, **kwargs)
    dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
    dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
    return dataloader_args
@tables.register("batch_sampler_classes", "CustomDistributedDynamicBatchSampler")
class CustomDistributedDynamicBatchSampler(Sampler):
    def __init__(self, dataset,
                 batch_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 drop_last=False,
                 is_training: bool = True,
                 **kwargs,
                 ):
        try:
            rank = dist.get_rank()
            num_replicas = dist.get_world_size()
        except:
            rank = 0
            num_replicas = 1
        self.rank = rank
        self.num_replicas = num_replicas
        self.dataset = dataset
        self.batch_size = batch_size
        self.is_training = is_training
        self.shuffle = shuffle and is_training
        self.drop_last = drop_last
        self.total_size = len(self.dataset)
        # self.num_samples = int(math.ceil(self.total_size / self.num_replicas))
        self.epoch = 0
    def __iter__(self):
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        indices = indices[self.rank:self.total_size:self.num_replicas]
        batches = []
        batch = []
        max_len_in_batch = 0
        current_batch_length = 0
        for idx in indices:
            sample_length = self.dataset.get_source_len(idx)
            potential_batch_length = (max_len_in_batch if sample_length < max_len_in_batch else sample_length) * (
                    len(batch) + 1)
            if potential_batch_length <= self.batch_size:
                batch.append(idx)
                if sample_length > max_len_in_batch:
                    max_len_in_batch = sample_length
                    current_batch_length = max_len_in_batch * len(batch)
            else:
                batches.append(batch)
                batch = [idx]
                max_len_in_batch = sample_length
                current_batch_length = max_len_in_batch
        # Add the last batch if it's not empty and we're not dropping it
        if batch and (not self.drop_last or len(batch) * max_len_in_batch == self.batch_size):
            batches.append(batch)
        return iter(batches)
    def __len__(self):
        return -1
    def set_epoch(self, epoch):
        self.epoch = epoch
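A self-contained toy run of the token-budget grouping done in CustomDistributedDynamicBatchSampler.__iter__, assuming get_source_len returns token counts and batch_size acts as a per-batch token budget:
def dynamic_batches(lengths, token_budget):
    # Group indices so that max_len_in_batch * (number of samples) stays within the budget.
    batches, batch, max_len = [], [], 0
    for idx, sample_len in enumerate(lengths):
        potential = max(max_len, sample_len) * (len(batch) + 1)
        if potential <= token_budget:
            batch.append(idx)
            max_len = max(max_len, sample_len)
        else:
            batches.append(batch)
            batch, max_len = [idx], sample_len
    if batch:
        batches.append(batch)
    return batches

print(dynamic_batches([120, 130, 90, 400, 60, 70], token_budget=400))
# -> [[0, 1, 2], [3], [4, 5]]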
funasr/models/paraformer/model.py
@@ -231,6 +231,7 @@
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size
        
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
funasr/schedulers/lambdalr_cus.py
@@ -15,3 +15,18 @@
            ]
        else:
            return [base_lr for base_lr in self.base_lrs]
class CustomLambdaLR(_LRScheduler):
    def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
        self.warmup_steps = train_config.warmup_steps
        self.total_steps = train_config.total_steps
        super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)
    def get_lr(self):
        step = self._step_count
        if step < self.warmup_steps:
            lr_scale = step / self.warmup_steps
        else:
            lr_scale = max(0.0, 1 - (step - self.warmup_steps) / (self.total_steps - self.warmup_steps))
        return [base_lr * lr_scale for base_lr in self.base_lrs]
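A worked example of the schedule produced by CustomLambdaLR (warmup_steps=1000 and total_steps=10000 are assumed values): the learning rate ramps up linearly during warmup, then decays linearly to zero.
warmup_steps, total_steps = 1000, 10000   # assumed values for illustration

def lr_scale(step):
    if step < warmup_steps:
        return step / warmup_steps
    return max(0.0, 1 - (step - warmup_steps) / (total_steps - warmup_steps))

print(lr_scale(500))    # 0.5  halfway through warmup
print(lr_scale(1000))   # 1.0  end of warmup
print(lr_scale(5500))   # 0.5  halfway through the decay phase
print(lr_scale(10000))  # 0.0  fully decayed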
funasr/train_utils/average_nbest_models.py
@@ -143,7 +143,7 @@
    return checkpoint_paths
@torch.no_grad()
def average_checkpoints(output_dir: str, last_n: int=5):
def average_checkpoints(output_dir: str, last_n: int=5, val_acc_list=[]):
    """
    Average the last 'last_n' checkpoints' model state_dicts.
    If a tensor is of type torch.int, perform sum instead of average.
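The new val_acc_list argument lets the averaging select checkpoints by validation accuracy rather than by recency; a minimal sketch of that selection (a hypothetical helper, not the function in the repo, ignoring the integer-tensor special case) might look like:
import torch

def average_best_state_dicts(checkpoint_paths, val_acc_list, last_n=5):
    # Rank checkpoints by validation accuracy and keep the top last_n.
    ranked = sorted(zip(checkpoint_paths, val_acc_list), key=lambda x: x[1], reverse=True)[:last_n]
    avg_state = None
    for path, _ in ranked:
        state = torch.load(path, map_location="cpu")["state_dict"]
        if avg_state is None:
            avg_state = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg_state[k] += v.float()
    return {k: v / len(ranked) for k, v in avg_state.items()}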
funasr/train_utils/trainer.py
@@ -103,6 +103,7 @@
        
        os.makedirs(os.path.join(self.output_dir, "tensorboard"), exist_ok=True)
        self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None
        
    
    def _save_checkpoint(self, epoch, step=None):
funasr/train_utils/trainer_llm.py
@@ -1,3 +1,4 @@
import math
import os
import time
import torch
@@ -61,6 +62,8 @@
        """
        
        self.output_dir = output_dir
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir, exist_ok=True)
        self.resume = kwargs.get('resume', True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get('max_epoch', 100)
@@ -78,6 +81,7 @@
        # scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
        # self.scaler = scaler
        self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
        self.keep_nbest_models = kwargs.get("keep_nbest_models", -1)
        self.accum_grad = kwargs.get("accum_grad", 1)
        self.grad_clip = kwargs.get("grad_clip", 10.0)
        self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
@@ -93,6 +97,15 @@
            logging.warning("distributed is not initialized, only single shard")
        self.rank = rank
        self.world_size = world_size
        self.train_acc_avg = 0.0
        self.train_loss_avg = 0.0
        self.val_acc_avg = 0.0
        self.val_loss_avg = 0.0
        self.best_acc_idx = 0
        self.saved_ckpts = {}
        self.val_acc_list = []
        self.step_or_epoch = -1
        
        
@@ -112,27 +125,55 @@
        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        if self.rank == 0:
            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
            self.step_or_epoch += 1
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optim.state_dict(),
                'scheduler': scheduler.state_dict(),
                "acc": self.val_acc_list,
                "step_or_epoch": self.step_or_epoch,
            }
            if hasattr(model, "module"):
                state["state_dict"] = model.module.state_dict()
            if scaler:
                state["scaler_state"] = scaler.state_dict()
            # Create output directory if it does not exist
            os.makedirs(self.output_dir, exist_ok=True)
            if step is None:
                filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
                ckpt_name = f'model.pt.ep{epoch}'
            else:
                filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')
                ckpt_name = f'model.pt.ep{epoch}.{step}'
            filename = os.path.join(self.output_dir, ckpt_name)
            torch.save(state, filename)
            
            print(f'\nCheckpoint saved to {filename}\n')
            logging.info(f'\nCheckpoint saved to {filename}\n')
            latest = Path(os.path.join(self.output_dir, f'model.pt'))
            torch.save(state, latest)
            if self.val_acc_list[self.step_or_epoch] >= self.val_acc_list[self.best_acc_idx]:
                self.best_acc_idx = self.step_or_epoch
                best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
                torch.save(state, best_ckpt)
                logging.info(f"Update best acc: {self.val_acc_list[self.best_acc_idx]}, {best_ckpt}")
            else:
                logging.info(f"No improvement in acc: {self.val_acc_list[self.best_acc_idx]}")
            if self.keep_nbest_models > 0:
                self.saved_ckpts[ckpt_name] = self.val_acc_list[-1]
                if len(self.saved_ckpts) > self.keep_nbest_models:
                    min_key = min(self.saved_ckpts, key=self.saved_ckpts.get)
                    if min_key in self.saved_ckpts:
                        del self.saved_ckpts[min_key]
                    filename = os.path.join(self.output_dir, min_key)
                    logging.info(f"Delete: {filename}")
                    if os.path.exists(filename):
                        os.remove(filename)
        
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
@@ -173,6 +214,10 @@
                scheduler.load_state_dict(checkpoint['scheduler'])
                if scaler is not None and 'scaler_state' in checkpoint:
                    scaler.load_state_dict(checkpoint['scaler_state'])
                self.val_acc_list = checkpoint["acc"]
                self.step_or_epoch = checkpoint["step_or_epoch"]
                print(f"Checkpoint loaded successfully from '{ckpt}'")
            else:
                print(f"No checkpoint found at '{ckpt}', does not resume status!")
@@ -180,51 +225,6 @@
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        
    # def train(self):
    #     """
    #     Starts the training process, iterating over epochs, training the model,
    #     and saving checkpoints at the end of each epoch.
    #     """
    #     if self.resume:
    #         self.resume_checkpoint(self.output_dir)
    #
    #     for epoch in range(self.start_epoch, self.max_epoch + 1):
    #         time1 = time.perf_counter()
    #         self.train_epoch(epoch)
    #
    #
    #
    #         if self.use_ddp or self.use_fsdp:
    #             dist.barrier()
    #
    #         self._validate_epoch(epoch)
    #
    #         if self.use_ddp or self.use_fsdp:
    #             dist.barrier()
    #
    #
    #         if self.rank == 0:
    #             self._save_checkpoint(epoch)
    #
    #         if self.use_ddp or self.use_fsdp:
    #             dist.barrier()
    #
    #         self.scheduler.step()
    #
    #         time2 = time.perf_counter()
    #         time_escaped = (time2 - time1)/3600.0
    #         print(f"\nrank: {self.local_rank}, time_escaped_epoch: {time_escaped:.3f} hours, estimated to finish {self.max_epoch} epoch: {(self.max_epoch-epoch)*time_escaped:.3f} hours\n")
    #
    #     if self.rank == 0:
    #         average_checkpoints(self.output_dir, self.avg_nbest_model)
    #
    #     if self.use_ddp or self.use_fsdp:
    #         dist.barrier()
    #
    #
    #     if writer:
    #         writer.close()
    #
    
    def train_epoch(self,
                model=None,
@@ -241,8 +241,8 @@
        Args:
            epoch (int): The current epoch number.
        """
        logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
        model.train()
        
        # Set the number of steps for gradient accumulation
        accum_grad = self.accum_grad
@@ -289,6 +289,18 @@
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
            
                self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
                if "acc" in stats:
                    self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
                if self.use_ddp or self.use_fsdp:
                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
            # Perform an optimizer step only after accumulating enough gradients
            if (batch_idx + 1) % accum_grad == 0:
                # Perform gradient clipping if it is set
@@ -322,9 +334,11 @@
    
                speed_stats["total_time"] = total_time
                lr = scheduler.get_last_lr()[0]
                batch_num_epoch = -1
                if hasattr(dataloader_train, "__len__"):
                    batch_num_epoch = len(dataloader_train)
                self.log(epoch, batch_idx,
                         batch_num_epoch=len(dataloader_train),
                         batch_num_epoch=batch_num_epoch,
                         lr=lr,
                         loss=loss.detach().cpu().item(),
                         speed_stats=speed_stats,
@@ -341,7 +355,7 @@
                    writer=writer
                )
            if (batch_idx+1) % self.save_checkpoint_interval == 0 and self.rank == 0:
            if (batch_idx+1) % self.save_checkpoint_interval == 0:
                self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
        
@@ -364,6 +378,7 @@
        Args:
            epoch (int): The current epoch number.
        """
        logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
        model.eval()
        
        with torch.no_grad():
@@ -394,18 +409,35 @@
                loss = loss
                time4 = time.perf_counter()
                self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
                if "acc" in stats:
                    self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
                if self.use_ddp or self.use_fsdp:
                    val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
                    val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
                    dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
                    dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
                    self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
                    self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
                
                batch_num_epoch = -1
                if hasattr(dataloader_val, "__len__"):
                    batch_num_epoch = len(dataloader_val)
                self.log(epoch, batch_idx,
                         batch_num_epoch=len(dataloader_val),
                         batch_num_epoch=batch_num_epoch,
                         lr=0.0,
                         loss=loss.detach().cpu().item(),
                         speed_stats=speed_stats,
                         stats=stats,
                         writer=writer,
                         tag="train",
                         tag="val",
                         )
        self.val_acc_list.append(self.val_acc_avg)
        model.train()
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        
        
    def log(self,
@@ -422,39 +454,47 @@
        
        if (batch_idx + 1) % self.log_interval == 0:
            
            gpu_info = "GPU, memory: {:.3f} GB, " \
                       "{:.3f} GB, " \
                       "{:.3f} GB, " \
                       "{:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
            gpu_info = "GPU, memory: usage: {:.3f} GB, " \
                       "peak: {:.3f} GB, " \
                       "cache: {:.3f} GB, " \
                       "cache_peak: {:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
                                          torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
                                          torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
                                          torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
                                          )
            
            time_now = datetime.now()
            time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
            loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
            acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
            description = (
                f"{time_now}, "
                f"{tag}, "
                f"rank: {self.local_rank}, "
                f"epoch: {epoch}/{self.max_epoch}, "
                f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
                f"(loss: {loss:.3f}), "
                f"(loss_avg_rank: {loss:.3f}), "
                f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
                f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}), "
                f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
                f"(lr: {lr:.3e}), "
                f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
                f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
                f"{speed_stats}, "
                f"{gpu_info}"
            )
            logging.info(description)
            
            if writer is not None:
                writer.add_scalar(f'rank{self.local_rank}_Loss/{tag}', loss, self.batch_total)
                writer.add_scalar(f'rank{self.local_rank}_loss/{tag}', loss, self.batch_total)
                writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
                writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
                for key, var in stats.items():
                    writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
                for key, var in speed_stats.items():
                    writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total)
                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total)
        
    def close(self, writer=None):
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        if writer is not None:
            writer.close()
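The per-epoch loss/acc averages are synchronized across ranks by summing with all_reduce and dividing by world_size; a minimal sketch of that pattern (assuming the process group is already initialized) is:
import torch
import torch.distributed as dist

def sync_mean(value: float, device, world_size: int) -> float:
    # Sum the per-rank scalar across all ranks, then divide to get the global mean,
    # mirroring what trainer_llm.py does for train/val loss and acc.
    t = torch.tensor(value, dtype=torch.float32, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return (t / world_size).item()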