From 3ac03e448b7673604eb86f619b27521fca55f34d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 21 Mar 2024 01:36:39 +0800
Subject: [PATCH] train & finetune llm-asr (#1519)

---
 funasr/models/paraformer/model.py                                 |    1 
 funasr/train_utils/trainer_llm.py                                 |  178 ++++++++++++-------
 funasr/schedulers/lambdalr_cus.py                                 |   17 +
 funasr/train_utils/trainer.py                                     |    1 
 funasr/train_utils/average_nbest_models.py                        |    2 
 funasr/datasets/llm_datasets_vicuna/samplers.py                   |  197 +++++++++++++++++++++
 funasr/bin/train_llm.py                                           |   78 ++++++--
 examples/industrial_data_pretraining/paraformer_streaming/demo.py |    7 
 8 files changed, 386 insertions(+), 95 deletions(-)

diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
index 9885c0b..601a531 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -8,20 +8,21 @@
 chunk_size = [5, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
 encoder_chunk_look_back = 0 #number of chunks to lookback for encoder self-attention
 decoder_chunk_look_back = 0 #number of encoder chunks to lookback for decoder cross-attention
-
+wav_file="/Users/zhifu/Downloads/NCYzUhAtZNI_0015.wav"
 model = AutoModel(model="iic/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.4")
-res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+res = model.generate(input=wav_file,
             chunk_size=chunk_size,
             encoder_chunk_look_back=encoder_chunk_look_back,
             decoder_chunk_look_back=decoder_chunk_look_back,
             )
 print(res)
 
+# exit()
 
 import soundfile
 import os
 
-wav_file = os.path.join(model.model_path, "example/asr_example.wav")
+# wav_file = os.path.join(model.model_path, "example/asr_example.wav")
 speech, sample_rate = soundfile.read(wav_file)
 
 chunk_stride = chunk_size[1] * 960 # 600ms、480ms
diff --git a/funasr/bin/train_llm.py b/funasr/bin/train_llm.py
index 8742bf1..89f5db3 100644
--- a/funasr/bin/train_llm.py
+++ b/funasr/bin/train_llm.py
@@ -4,18 +4,21 @@
 import os
 import sys
 import torch
+import torch.nn as nn
 import hydra
 import logging
 import time
 import argparse
 from io import BytesIO
 
+from contextlib import nullcontext
 import torch.distributed as dist
 from collections.abc import Sequence
 from omegaconf import DictConfig, OmegaConf
 from torch.cuda.amp import autocast, GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.algorithms.join import Join
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from funasr.train_utils.average_nbest_models import average_checkpoints
 
@@ -48,7 +51,6 @@
 
 
 def main(**kwargs):
-    print(kwargs)
     
     # set random seed
     set_all_random_seed(kwargs.get("seed", 0))
@@ -61,11 +63,13 @@
         tables.print()
     # Check if we are using DDP or FSDP
     use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
-    use_fsdp = kwargs.get("use_fsdp", None)
+    use_fsdp = kwargs.get("use_fsdp", False)
+    # use_ddp = False if use_fsdp else use_fsdp
     if use_ddp or use_fsdp:
         dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
         torch.cuda.set_device(local_rank)
-        
+
+    logging.info("Build model, frontend, tokenizer")
     device = kwargs.get("device", "cuda")
     kwargs["device"] = "cpu"
     model = AutoModel(**kwargs)
@@ -76,6 +80,7 @@
         os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
         yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
         OmegaConf.save(config=kwargs, f=yaml_file)
+        print(kwargs)
         logging.info("config.yaml is saved to: %s", yaml_file)
     
     # parse kwargs
@@ -105,19 +110,42 @@
         model = DDP(model, device_ids=[local_rank],
                     find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
     elif use_fsdp:
-        model = FSDP(model).cuda(local_rank)
+        # model = FSDP(model).cuda(local_rank)
+
+        def custom_auto_wrap_policy(
+            module: nn.Module,
+            recurse: bool,
+            nonwrapped_numel: int,
+            # Additional custom arguments
+            min_num_params: int = int(1e8),
+        ) -> bool:
+            # Decide whether to wrap this module based on custom logic
+            is_large = nonwrapped_numel >= min_num_params
+            requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
+            return is_large and requires_grad_uniform
+
+        # Configure a custom `min_num_params`
+        my_auto_wrap_policy = lambda *args, **kw: custom_auto_wrap_policy(*args, min_num_params=int(1e5), **kw)
+        torch.cuda.set_device(local_rank)
+        model = FSDP(model,
+                     auto_wrap_policy=my_auto_wrap_policy,
+                     mixed_precision=None,
+                     device_id=torch.cuda.current_device())
     else:
         model = model.to(device=kwargs.get("device", "cuda"))
 
+    logging.info(f"{model}")
     kwargs["device"] = next(model.parameters()).device
         
     # optim
+    logging.info("Build optim")
     optim = kwargs.get("optim", "adam")
     assert optim in optim_classes
     optim_class = optim_classes.get(optim)
     optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
     
     # scheduler
+    logging.info("Build scheduler")
     scheduler = kwargs.get("scheduler", "warmuplr")
     assert scheduler in scheduler_classes
     scheduler_class = scheduler_classes.get(scheduler)
@@ -125,6 +153,7 @@
 
 
     # dataset
+    logging.info("Build dataloader")
     dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
     dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
     dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
@@ -142,8 +171,9 @@
 
     trainer = Trainer(local_rank=local_rank,
                       use_ddp=use_ddp,
-                      resume=kwargs.get("resume", True),
+                      use_fsdp=use_fsdp,
                       device=kwargs["device"],
+                      output_dir=kwargs.get("output_dir", "./exp"),
                       **kwargs.get("train_conf"),
                       )
 
@@ -159,20 +189,27 @@
         writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
     except:
         writer = None
-    
+
+    if use_ddp or use_fsdp:
+        context = Join([model])
+    else:
+        context = nullcontext()
+
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
-        trainer.train_epoch(
-                            model=model,
-                            optim=optim,
-                            scheduler=scheduler,
-                            scaler=scaler,
-                            dataloader_train=dataloader_tr,
-                            dataloader_val=dataloader_val,
-                            epoch=epoch,
-                            writer=writer
-                            )
-
+        with context:
+            
+            trainer.train_epoch(
+                                model=model,
+                                optim=optim,
+                                scheduler=scheduler,
+                                scaler=scaler,
+                                dataloader_train=dataloader_tr,
+                                dataloader_val=dataloader_val,
+                                epoch=epoch,
+                                writer=writer
+                                )
+        scheduler.step()
         trainer.validate_epoch(
             model=model,
             dataloader_val=dataloader_val,
@@ -180,21 +217,20 @@
             writer=writer
         )
 
+        
         trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
-
-        scheduler.step()
 
         time2 = time.perf_counter()
         time_escaped = (time2 - time1) / 3600.0
         logging.info(
-            f"\nrank: {local_rank}, "
+            f"rank: {local_rank}, "
             f"time_escaped_epoch: {time_escaped:.3f} hours, "
             f"estimated to finish {trainer.max_epoch} "
             f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n")
 
 
     if trainer.rank == 0:
-        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
+        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)
 
     trainer.close()
 
diff --git a/funasr/datasets/llm_datasets_vicuna/samplers.py b/funasr/datasets/llm_datasets_vicuna/samplers.py
index c728d9c..61f7d94 100644
--- a/funasr/datasets/llm_datasets_vicuna/samplers.py
+++ b/funasr/datasets/llm_datasets_vicuna/samplers.py
@@ -232,3 +232,200 @@
 
     def set_epoch(self, epoch):
         self.epoch = epoch
+
+
+@tables.register("batch_sampler_classes", "CustomDistributedBufferBatchSampler_fn")
+def CustomDistributedBatchSampler_fn(dataset, **kwargs):
+    dataloader_args = {}
+    dataloader_args["batch_sampler"] = CustomDistributedBufferBatchSampler(dataset, **kwargs)
+    dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
+    dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
+    
+    return dataloader_args
+
+
+@tables.register("batch_sampler_classes", "CustomDistributedBufferBatchSampler")
+class CustomDistributedBatchSampler(Sampler):
+    def __init__(self, dataset,
+                 batch_size,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=True,
+                 drop_last=False,
+                 is_training: bool = True,
+                 sort_size: int = 1024,
+                 **kwargs,
+                 ):
+        
+        try:
+            rank = dist.get_rank()
+            num_replicas = dist.get_world_size()
+        except:
+            rank = 0
+            num_replicas = 1
+        self.rank = rank
+        self.num_replicas = num_replicas
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.is_training = is_training
+        self.shuffle = shuffle and is_training
+        self.drop_last = drop_last
+        # self.total_size = len(dataset)
+        if self.drop_last:
+            self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (batch_size * num_replicas)
+        else:
+            self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (batch_size * num_replicas)
+        self.num_samples = int(self.total_size // self.num_replicas)
+        self.epoch = 0
+        self.max_token_length = kwargs.get("max_token_length", None)
+        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+        self.sort_size = sort_size
+    
+    def __iter__(self):
+        # Generate a list of indices
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
+        
+        # Add extra samples to make it evenly divisible
+        padding_size = self.total_size - len(indices)
+        if padding_size <= len(indices):
+            indices += indices[:padding_size]
+        else:
+            indices += (indices * (padding_size // len(indices)) + indices[:padding_size % len(indices)])
+        
+        assert len(indices) == self.total_size
+        
+        # Subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+        
+        # Filter out indices with length greater than the max length, if provided
+        if self.max_token_length is not None:
+            filtered_indices = []
+            for idx in indices:
+                source_len = self.dataset.get_source_len(idx) / self.length_scale_source
+                if source_len <= self.max_token_length:
+                    filtered_indices.append(idx)
+            indices = filtered_indices
+
+        # Buffer sorting logic
+        sorted_batches = []
+        buffer = []
+
+        for idx in indices:
+            buffer.append(idx)
+            if len(buffer) >= self.sort_size:
+                # Sort the buffer based on some criteria, e.g., dataset sample length
+                buffer.sort(key=lambda x: self.dataset.get_source_len(x))
+                sorted_batches.extend(self._create_batches_from_buffer(buffer))
+                buffer = []
+
+        # Handle the remaining items in the buffer
+        if buffer:
+            buffer.sort(key=lambda x: self.dataset.get_source_len(x))
+            sorted_batches.extend(self._create_batches_from_buffer(buffer))
+
+        return iter(sorted_batches)
+
+    def _create_batches_from_buffer(self, buffer):
+        # Function to convert the sorted buffer into batches
+        batched_buffer = [buffer[i:i + self.batch_size] for i in range(0, len(buffer), self.batch_size)]
+        if self.drop_last and len(batched_buffer[-1]) != self.batch_size:
+            batched_buffer = batched_buffer[:-1]
+        return batched_buffer
+    
+    def __len__(self):
+        
+        return self.num_samples // self.batch_size
+    
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+
+@tables.register("batch_sampler_classes", "CustomDistributedDynamicBatchSampler_fn")
+def CustomDistributedBatchSampler_fn(dataset, **kwargs):
+    dataloader_args = {}
+    dataloader_args["batch_sampler"] = CustomDistributedDynamicBatchSampler(dataset, **kwargs)
+    dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
+    dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
+    
+    return dataloader_args
+
+
+@tables.register("batch_sampler_classes", "CustomDistributedDynamicBatchSampler")
+class CustomDistributedDynamicBatchSampler(Sampler):
+    def __init__(self, dataset,
+                 batch_size,
+                 num_replicas=None,
+                 rank=None,
+                 shuffle=True,
+                 drop_last=False,
+                 is_training: bool = True,
+                 **kwargs,
+                 ):
+        
+        try:
+            rank = dist.get_rank()
+            num_replicas = dist.get_world_size()
+        except:
+            rank = 0
+            num_replicas = 1
+        self.rank = rank
+        self.num_replicas = num_replicas
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.is_training = is_training
+        self.shuffle = shuffle and is_training
+        self.drop_last = drop_last
+        
+        self.total_size = len(self.dataset)
+        # self.num_samples = int(math.ceil(self.total_size / self.num_replicas))
+        self.epoch = 0
+    
+    def __iter__(self):
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
+        
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        
+        batches = []
+        batch = []
+        max_len_in_batch = 0
+        current_batch_length = 0
+        
+        for idx in indices:
+            sample_length = self.dataset.get_source_len(idx)
+            potential_batch_length = (max_len_in_batch if sample_length < max_len_in_batch else sample_length) * (
+                    len(batch) + 1)
+            
+            if potential_batch_length <= self.batch_size:
+                batch.append(idx)
+                if sample_length > max_len_in_batch:
+                    max_len_in_batch = sample_length
+                    current_batch_length = max_len_in_batch * len(batch)
+            else:
+                batches.append(batch)
+                batch = [idx]
+                max_len_in_batch = sample_length
+                current_batch_length = max_len_in_batch
+        
+        # Add the last batch if it's not empty and we're not dropping it
+        if batch and (not self.drop_last or len(batch) * max_len_in_batch == self.batch_size):
+            batches.append(batch)
+        
+        return iter(batches)
+    
+    def __len__(self):
+        
+        return -1
+    
+    def set_epoch(self, epoch):
+        self.epoch = epoch
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 316255d..bd85df0 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -231,6 +231,7 @@
         stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
         
         stats["loss"] = torch.clone(loss.detach())
+        stats["batch_size"] = batch_size
         
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
diff --git a/funasr/schedulers/lambdalr_cus.py b/funasr/schedulers/lambdalr_cus.py
index 0123cc2..5aad049 100644
--- a/funasr/schedulers/lambdalr_cus.py
+++ b/funasr/schedulers/lambdalr_cus.py
@@ -14,4 +14,19 @@
                 for base_lr in self.base_lrs
             ]
         else:
-            return [base_lr for base_lr in self.base_lrs]
\ No newline at end of file
+            return [base_lr for base_lr in self.base_lrs]
+        
+        
+class CustomLambdaLR(_LRScheduler):
+    def __init__(self, optimizer, train_config, last_epoch=-1, verbose=False):
+        self.warmup_steps = train_config.warmup_steps
+        self.total_steps = train_config.total_steps
+        super(CustomLambdaLR, self).__init__(optimizer, last_epoch, verbose)
+
+    def get_lr(self):
+        step = self._step_count
+        if step < self.warmup_steps:
+            lr_scale = step / self.warmup_steps
+        else:
+            lr_scale = max(0.0, 1 - (step - self.warmup_steps) / (self.total_steps - self.warmup_steps))
+        return [base_lr * lr_scale for base_lr in self.base_lrs]
diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py
index f117804..3603a44 100644
--- a/funasr/train_utils/average_nbest_models.py
+++ b/funasr/train_utils/average_nbest_models.py
@@ -143,7 +143,7 @@
     return checkpoint_paths
 
 @torch.no_grad()
-def average_checkpoints(output_dir: str, last_n: int=5):
+def average_checkpoints(output_dir: str, last_n: int=5, val_acc_list=[]):
     """
     Average the last 'last_n' checkpoints' model state_dicts.
     If a tensor is of type torch.int, perform sum instead of average.
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 14abd6c..aae4513 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -103,6 +103,7 @@
         
         os.makedirs(os.path.join(self.output_dir, "tensorboard"), exist_ok=True)
         self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None
+
         
     
     def _save_checkpoint(self, epoch, step=None):
diff --git a/funasr/train_utils/trainer_llm.py b/funasr/train_utils/trainer_llm.py
index 6a3b83b..5f13b5a 100644
--- a/funasr/train_utils/trainer_llm.py
+++ b/funasr/train_utils/trainer_llm.py
@@ -1,3 +1,4 @@
+import math
 import os
 import time
 import torch
@@ -61,6 +62,8 @@
         """
         
         self.output_dir = output_dir
+        if not os.path.exists(self.output_dir):
+            os.makedirs(self.output_dir, exist_ok=True)
         self.resume = kwargs.get('resume', True)
         self.start_epoch = 0
         self.max_epoch = kwargs.get('max_epoch', 100)
@@ -78,6 +81,7 @@
         # scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
         # self.scaler = scaler
         self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
+        self.keep_nbest_models = kwargs.get("keep_nbest_models", -1)
         self.accum_grad = kwargs.get("accum_grad", 1)
         self.grad_clip = kwargs.get("grad_clip", 10.0)
         self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
@@ -93,6 +97,15 @@
             logging.warning("distributed is not initialized, only single shard")
         self.rank = rank
         self.world_size = world_size
+        self.train_acc_avg = 0.0
+        self.train_loss_avg = 0.0
+        self.val_acc_avg = 0.0
+        self.val_loss_avg = 0.0
+        self.best_acc_idx = 0
+        self.saved_ckpts = {}
+        self.val_acc_list = []
+        self.step_or_epoch = -1
+        
         
 
         
@@ -112,28 +125,56 @@
         Args:
             epoch (int): The epoch number at which the checkpoint is being saved.
         """
+        
         if self.rank == 0:
+            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
+            self.step_or_epoch += 1
             state = {
                 'epoch': epoch,
                 'state_dict': model.state_dict(),
                 'optimizer': optim.state_dict(),
                 'scheduler': scheduler.state_dict(),
+                "acc": self.val_acc_list,
+                "step_or_epoch": self.step_or_epoch,
             }
+            if hasattr(model, "module"):
+                state["state_dict"] = model.module.state_dict()
+                
             if scaler:
                 state["scaler_state"] = scaler.state_dict()
             # Create output directory if it does not exist
             os.makedirs(self.output_dir, exist_ok=True)
             if step is None:
-                filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
+                ckpt_name = f'model.pt.ep{epoch}'
             else:
-                filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')
-            
+                ckpt_name = f'model.pt.ep{epoch}.{step}'
+            filename = os.path.join(self.output_dir, ckpt_name)
             torch.save(state, filename)
             
-            print(f'\nCheckpoint saved to {filename}\n')
+            logging.info(f'\nCheckpoint saved to {filename}\n')
             latest = Path(os.path.join(self.output_dir, f'model.pt'))
             torch.save(state, latest)
-        
+            
+            if self.val_acc_list[self.step_or_epoch] >= self.val_acc_list[self.best_acc_idx]:
+                self.best_acc_idx = self.step_or_epoch
+                best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
+                torch.save(state, best_ckpt)
+                logging.info(f"Update best acc: {self.val_acc_list[self.best_acc_idx]}, {best_ckpt}")
+            else:
+                logging.info(f"No improvement in acc: {self.val_acc_list[self.best_acc_idx]}")
+            
+            if self.keep_nbest_models > 0:
+                self.saved_ckpts[ckpt_name] = self.val_acc_list[-1]
+                if len(self.saved_ckpts) > self.keep_nbest_models:
+
+                    min_key = min(self.saved_ckpts, key=self.saved_ckpts.get)
+                    if min_key in self.saved_ckpts:
+                        del self.saved_ckpts[min_key]
+                    filename = os.path.join(self.output_dir, min_key)
+                    logging.info(f"Delete: {filename}")
+                    if os.path.exists(filename):
+                        os.remove(filename)
+                
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
     
@@ -173,6 +214,10 @@
                 scheduler.load_state_dict(checkpoint['scheduler'])
                 if scaler is not None and 'scaler_state' in checkpoint:
                     scaler.load_state_dict(checkpoint['scaler_state'])
+                
+                self.val_acc_list = checkpoint["acc"]
+                self.step_or_epoch = checkpoint["step_or_epoch"]
+                
                 print(f"Checkpoint loaded successfully from '{ckpt}'")
             else:
                 print(f"No checkpoint found at '{ckpt}', does not resume status!")
@@ -180,52 +225,7 @@
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
         
-    # def train(self):
-    #     """
-    #     Starts the training process, iterating over epochs, training the model,
-    #     and saving checkpoints at the end of each epoch.
-    #     """
-    #     if self.resume:
-    #         self.resume_checkpoint(self.output_dir)
-    #
-    #     for epoch in range(self.start_epoch, self.max_epoch + 1):
-    #         time1 = time.perf_counter()
-    #         self.train_epoch(epoch)
-    #
-    #
-    #
-    #         if self.use_ddp or self.use_fsdp:
-    #             dist.barrier()
-    #
-    #         self._validate_epoch(epoch)
-    #
-    #         if self.use_ddp or self.use_fsdp:
-    #             dist.barrier()
-    #
-    #
-    #         if self.rank == 0:
-    #             self._save_checkpoint(epoch)
-    #
-    #         if self.use_ddp or self.use_fsdp:
-    #             dist.barrier()
-    #
-    #         self.scheduler.step()
-    #
-    #         time2 = time.perf_counter()
-    #         time_escaped = (time2 - time1)/3600.0
-    #         print(f"\nrank: {self.local_rank}, time_escaped_epoch: {time_escaped:.3f} hours, estimated to finish {self.max_epoch} epoch: {(self.max_epoch-epoch)*time_escaped:.3f} hours\n")
-    #
-    #     if self.rank == 0:
-    #         average_checkpoints(self.output_dir, self.avg_nbest_model)
-    #
-    #     if self.use_ddp or self.use_fsdp:
-    #         dist.barrier()
-    #
-    #
-    #     if writer:
-    #         writer.close()
-    #
-    
+ 
     def train_epoch(self,
                 model=None,
                 optim=None,
@@ -241,9 +241,9 @@
         Args:
             epoch (int): The current epoch number.
         """
+        logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
         model.train()
 
-        
         # Set the number of steps for gradient accumulation
         accum_grad = self.accum_grad
         # Initialize the gradient accumulation
@@ -288,6 +288,18 @@
                     loss.backward()
                 time4 = time.perf_counter()
                 speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
+                
+                self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
+                if "acc" in stats:
+                    self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
+                if self.use_ddp or self.use_fsdp:
+                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
+                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
+                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+                
             
             # Perform an optimizer step only after accumulating enough gradients
             if (batch_idx + 1) % accum_grad == 0:
@@ -322,9 +334,11 @@
     
                 speed_stats["total_time"] = total_time
                 lr = scheduler.get_last_lr()[0]
-
+                batch_num_epoch = -1
+                if hasattr(dataloader_train, "__len__"):
+                    batch_num_epoch = len(dataloader_train)
                 self.log(epoch, batch_idx,
-                         batch_num_epoch=len(dataloader_train),
+                         batch_num_epoch=batch_num_epoch,
                          lr=lr,
                          loss=loss.detach().cpu().item(),
                          speed_stats=speed_stats,
@@ -341,7 +355,7 @@
                     writer=writer
                 )
 
-            if (batch_idx+1) % self.save_checkpoint_interval == 0 and self.rank == 0:
+            if (batch_idx+1) % self.save_checkpoint_interval == 0:
                 self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
 
         
@@ -364,6 +378,7 @@
         Args:
             epoch (int): The current epoch number.
         """
+        logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
         model.eval()
         
         with torch.no_grad():
@@ -394,18 +409,35 @@
                 loss = loss
                 time4 = time.perf_counter()
 
+                self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
+                if "acc" in stats:
+                    self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
+                if self.use_ddp or self.use_fsdp:
+                    val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
+                    val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
+                    dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+                    self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+                    self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
                 
+                batch_num_epoch = -1
+                if hasattr(dataloader_val, "__len__"):
+                    batch_num_epoch = len(dataloader_val)
                 self.log(epoch, batch_idx,
-                         batch_num_epoch=len(dataloader_val),
+                         batch_num_epoch=batch_num_epoch,
                          lr=0.0,
                          loss=loss.detach().cpu().item(),
                          speed_stats=speed_stats,
                          stats=stats,
                          writer=writer,
-                         tag="train",
+                         tag="val",
                          )
 
+        self.val_acc_list.append(self.val_acc_avg)
         model.train()
+        
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
         
         
     def log(self,
@@ -422,39 +454,47 @@
         
         if (batch_idx + 1) % self.log_interval == 0:
             
-            gpu_info = "GPU, memory: {:.3f} GB, " \
-                       "{:.3f} GB, " \
-                       "{:.3f} GB, " \
-                       "{:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
+            gpu_info = "GPU, memory: usage: {:.3f} GB, " \
+                       "peak: {:.3f} GB, " \
+                       "cache: {:.3f} GB, " \
+                       "cache_peak: {:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
                                           torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
                                           torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
                                           torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
                                           )
             
-            time_now = datetime.now()
-            time_now = time_now.strftime("%Y-%m-%d %H:%M:%S")
+            loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
+            acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
             description = (
-                f"{time_now}, "
+                f"{tag}, "
                 f"rank: {self.local_rank}, "
                 f"epoch: {epoch}/{self.max_epoch}, "
                 f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
-                f"(loss: {loss:.3f}), "
+                f"(loss_avg_rank: {loss:.3f}), "
+                f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
+                f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}), "
+                f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
                 f"(lr: {lr:.3e}), "
-                f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, "
+                f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
                 f"{speed_stats}, "
                 f"{gpu_info}"
             )
             logging.info(description)
             
             if writer is not None:
-                writer.add_scalar(f'rank{self.local_rank}_Loss/{tag}', loss, self.batch_total)
+                writer.add_scalar(f'rank{self.local_rank}_loss/{tag}', loss, self.batch_total)
+                writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
                 writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
                 for key, var in stats.items():
-                    writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
+                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
                 for key, var in speed_stats.items():
-                    writer.add_scalar(f'rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total)
+                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total)
         
     def close(self, writer=None):
+        
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
+        
         if writer is not None:
             writer.close()
     

--
Gitblit v1.9.1