游雁
2024-03-24 873cfae5c347b940e38e853d8579a6b4e85ada05
funasr/train_utils/trainer.py
@@ -1,14 +1,27 @@
import math
import os
import time
import torch
import logging
from tqdm import tqdm
from datetime import datetime
import torch.distributed as dist
from contextlib import nullcontext
from torch.cuda.amp import autocast, GradScaler
from contextlib import nullcontext, contextmanager
from pathlib import Path
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
@contextmanager
def maybe_autocast(enabled):
    """
    Context manager that optionally wraps its body in ``torch.cuda.amp.autocast``.

    Args:
        enabled: When truthy, the body runs inside ``autocast()``;
            otherwise the body runs with no extra context.
    """
    # Pick the context once, then use a single ``with`` for both cases.
    ctx = autocast() if enabled else nullcontext()
    with ctx:
        yield
class Trainer:
    """
@@ -26,14 +39,12 @@
        resume (str, optional): Path to a checkpoint to resume training from.
    """
    
    def __init__(self, model,
                 optim,
                 scheduler,
                 dataloader_train,
                 dataloader_val,
    def __init__(self,
                 local_rank,
                 use_ddp=False,
                 use_fsdp=False,
                 use_ddp: bool = False,
                 use_fsdp: bool = False,
                 use_fp16: bool = False,
                 output_dir: str="./",
                 **kwargs):
        """
        Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
@@ -50,23 +61,32 @@
                      resume (str, optional): The file path to a checkpoint to resume training from.
        """
        
        self.model = model
        self.optim = optim
        self.scheduler = scheduler
        self.dataloader_train = dataloader_train
        self.dataloader_val = dataloader_val
        self.output_dir = kwargs.get('output_dir', './')
        self.output_dir = output_dir
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir, exist_ok=True)
        self.resume = kwargs.get('resume', True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get('max_epoch', 100)
        self.local_rank = local_rank
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = next(model.parameters()).device
        self.kwargs = kwargs
        self.device = kwargs.get('device', "cuda")
        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
        # self.kwargs = kwargs
        self.log_interval = kwargs.get("log_interval", 50)
        self.batch_total = 0
        self.use_fp16 = use_fp16
        self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
        # scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
        # scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
        # self.scaler = scaler
        self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
        self.keep_nbest_models = kwargs.get("keep_nbest_models", -1)
        self.accum_grad = kwargs.get("accum_grad", 1)
        self.grad_clip = kwargs.get("grad_clip", 10.0)
        self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
        self.validate_interval = kwargs.get("validate_interval", 5000)
        
        if self.resume:
            self._resume_checkpoint(self.resume)
    
        try:
            rank = dist.get_rank()
@@ -77,8 +97,22 @@
            logging.warning("distributed is not initialized, only single shard")
        self.rank = rank
        self.world_size = world_size
    def _save_checkpoint(self, epoch):
        self.train_acc_avg = 0.0
        self.train_loss_avg = 0.0
        self.val_acc_avg = 0.0
        self.val_loss_avg = 0.0
        self.best_acc_idx = 0
        self.saved_ckpts = {}
        self.val_acc_list = []
        self.step_or_epoch = -1
    def save_checkpoint(self, epoch,
                        step=None,
                        model=None,
                        optim=None,
                        scheduler=None,
                        scaler=None,
                        ):
        """
        Saves a checkpoint containing the model's state, the optimizer's state,
        and the scheduler's state under ``self.output_dir``.

        On rank 0 the method writes three kinds of files: a per-save file
        (``model.pt.ep{epoch}`` or ``model.pt.ep{epoch}.{step}``), a rolling
        ``model.pt``, and ``model.pt.best`` whenever the accuracy recorded for
        this save is at least the best accuracy seen so far. When
        ``self.keep_nbest_models`` is positive, the tracked checkpoint with the
        lowest accuracy is deleted once more than that many are tracked.

        NOTE(review): the first statements of the body (the ``state`` built from
        ``self.model`` / ``self.optim`` / ``self.scheduler`` and written to
        ``model.e{epoch}.pb``) run on every rank and look like they were left
        over from the previous implementation -- confirm whether they should
        still be here.

        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
            step (int, optional): If given, it is appended to the checkpoint file name.
            model: Object whose ``state_dict()`` is stored; if it has a ``module``
                attribute, ``model.module.state_dict()`` is stored instead.
            optim: Object whose ``state_dict()`` is stored under ``'optimizer'``.
            scheduler: Object whose ``state_dict()`` is stored under ``'scheduler'``.
            scaler (optional): If truthy, its ``state_dict()`` is stored under
                ``"scaler_state"``.
        """
        # NOTE(review): this first save uses the attributes on ``self`` rather than
        # the ``model`` / ``optim`` / ``scheduler`` arguments, and is not guarded
        # by the rank check below -- presumably superseded code; verify.
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optim.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        # Create output directory if it does not exist
        os.makedirs(self.output_dir, exist_ok=True)
        filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
        torch.save(state, filename)
        print(f'Checkpoint saved to {filename}')
        # Only rank 0 writes the checkpoint files below.
        if self.rank == 0:
            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
            # Counts how many times this method has saved; used below as an
            # index into self.val_acc_list.
            self.step_or_epoch += 1
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optim.state_dict(),
                'scheduler': scheduler.state_dict(),
                "acc": self.val_acc_list,
                "step_or_epoch": self.step_or_epoch,
            }
            # Store the wrapped module's weights when ``model`` exposes ``.module``.
            if hasattr(model, "module"):
                state["state_dict"] = model.module.state_dict()
            if scaler:
                state["scaler_state"] = scaler.state_dict()
            # Create output directory if it does not exist
            os.makedirs(self.output_dir, exist_ok=True)
            if step is None:
                ckpt_name = f'model.pt.ep{epoch}'
            else:
                ckpt_name = f'model.pt.ep{epoch}.{step}'
            filename = os.path.join(self.output_dir, ckpt_name)
            torch.save(state, filename)
            logging.info(f'\nCheckpoint saved to {filename}\n')
            # Rolling "latest" checkpoint, overwritten on every save.
            latest = Path(os.path.join(self.output_dir, f'model.pt'))
            torch.save(state, latest)
            # NOTE(review): self.val_acc_list is indexed by the save counter; this
            # raises IndexError if fewer accuracies have been appended than saves
            # performed (validate_epoch appends one entry per call) -- confirm the
            # callers always validate before each save.
            if self.val_acc_list[self.step_or_epoch] >= self.val_acc_list[self.best_acc_idx]:
                self.best_acc_idx = self.step_or_epoch
                best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
                torch.save(state, best_ckpt)
                logging.info(f"Update best acc: {self.val_acc_list[self.best_acc_idx]}, {best_ckpt}")
            else:
                logging.info(f"No improvement in acc: {self.val_acc_list[self.best_acc_idx]}")
            if self.keep_nbest_models > 0:
                # Track this checkpoint by its most recent accuracy and drop the
                # lowest-accuracy one once the limit is exceeded.
                self.saved_ckpts[ckpt_name] = self.val_acc_list[-1]
                if len(self.saved_ckpts) > self.keep_nbest_models:
                    min_key = min(self.saved_ckpts, key=self.saved_ckpts.get)
                    if min_key in self.saved_ckpts:
                        del self.saved_ckpts[min_key]
                    filename = os.path.join(self.output_dir, min_key)
                    logging.info(f"Delete: {filename}")
                    if os.path.exists(filename):
                        os.remove(filename)
        # All ranks wait here in distributed mode.
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
    
    def _resume_checkpoint(self, resume_path):
    def resume_checkpoint(self,
                          model=None,
                          optim=None,
                          scheduler=None,
                          scaler=None,
                          ):
        """
        Resumes training from a checkpoint at the given file path.
        Loads the model's state, the optimizer's state, and the scheduler's state.
@@ -107,55 +187,228 @@
        Args:
            resume_path (str): The file path to the checkpoint to resume from.
        """
        if os.path.isfile(resume_path):
            checkpoint = torch.load(resume_path)
            self.start_epoch = checkpoint['epoch'] + 1
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
        else:
            print(f"No checkpoint found at '{resume_path}', starting from scratch")
    def run(self):
        """
        Starts the training process, iterating over epochs, training the model,
        and saving checkpoints at the end of each epoch.
        """
        for epoch in range(self.start_epoch, self.max_epoch + 1):
            self._train_epoch(epoch)
            # self._validate_epoch(epoch)
            if self.rank == 0:
                self._save_checkpoint(epoch)
            self.scheduler.step()
        if self.resume:
            ckpt = os.path.join(self.output_dir, "model.pt")
            if os.path.isfile(ckpt):
                checkpoint = torch.load(ckpt)
                self.start_epoch = checkpoint['epoch'] + 1
                # self.model.load_state_dict(checkpoint['state_dict'])
                src_state = checkpoint['state_dict']
                dst_state = model.state_dict()
                for k in dst_state.keys():
                    if not k.startswith("module.") and "module."+k in src_state.keys():
                        k_ddp = "module."+k
                    elif k.startswith("module.") and "module."+k not in src_state.keys():
                        k_ddp = k.replace("module.", "", 1)
                    else:
                        k_ddp = k
                    if k_ddp in src_state.keys():
                        dst_state[k] = src_state[k_ddp]
                    else:
                        print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
    
    def _train_epoch(self, epoch):
                model.load_state_dict(dst_state)
                optim.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                if scaler is not None and 'scaler_state' in checkpoint:
                    scaler.load_state_dict(checkpoint['scaler_state'])
                self.val_acc_list = checkpoint["acc"]
                self.step_or_epoch = checkpoint["step_or_epoch"]
                print(f"Checkpoint loaded successfully from '{ckpt}'")
            else:
                print(f"No checkpoint found at '{ckpt}', does not resume status!")
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
    def train_epoch(self,
                model=None,
                optim=None,
                scheduler=None,
                scaler=None,
                dataloader_train=None,
                dataloader_val=None,
                epoch=None,
                writer=None,
                    ):
        """
        Defines the training process for a single epoch with gradient accumulation.
        Args:
            epoch (int): The current epoch number.
        """
        self.model.train()
        pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train),
                    dynamic_ncols=True)
        logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
        model.train()
        # Set the number of steps for gradient accumulation
        accum_grad = self.kwargs.get("accum_grad", 1)
        accum_grad = self.accum_grad
        # Initialize the gradient accumulation
        self.optim.zero_grad()
        optim.zero_grad()
        speed_stats = {}
        time5 = time.perf_counter()
        for batch_idx, batch in enumerate(self.dataloader_train):
        iterator_stop = torch.tensor(0).to(self.device)
        dist.barrier()
        print(f"before iter, iterator_stop: {iterator_stop}\n")
        dataloader_train.batch_sampler.set_epoch(epoch)
        for batch_idx, batch in enumerate(dataloader_train):
            if self.use_ddp or self.use_fsdp:
                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
                if iterator_stop > 0:
                    break
            self.batch_total += 1
            time1 = time.perf_counter()
            speed_stats["data_load"] = f"{time1-time5:0.3f}"
            # import pdb;
            # pdb.set_trace()
            batch = to_device(batch, self.device)
            
            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            my_context = model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            with my_context():
                time2 = time.perf_counter()
                retval = self.model(**batch)
                with maybe_autocast(self.use_fp16):
                    retval = model(**batch)
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                if self.use_ddp or self.use_fsdp:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # if distributed, this method can also apply all_reduce()
                    # stats, weight = recursive_average(stats, weight, distributed=True)
                    if self.use_ddp or self.use_fsdp:
                        dist.all_reduce(weight, op=dist.ReduceOp.SUM)
                    # Now weight is summation over all workers
                    loss /= weight.sum() # shape:[1] -> shape:[]
                    # Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
                if self.use_fp16:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
                self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
                if "acc" in stats:
                    self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
                # if self.use_ddp or self.use_fsdp:
                #     train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
                #     train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
                #     dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
                #     dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
                #     self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
                #     self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
            # Perform an optimizer step only after accumulating enough gradients
            if (batch_idx + 1) % accum_grad == 0:
                # Perform gradient clipping if it is set
                if self.grad_clip > 0:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        max_norm=self.grad_clip,
                        norm_type=self.grad_clip_type,
                    )
                    if not torch.isfinite(grad_norm):
                        logging.warning(
                            f"The grad norm is {grad_norm}. Skipping updating the model."
                        )
                        optim.zero_grad()  # Reset gradients
                        continue
                # Execute an optimization step (update model parameters)
                if self.use_ddp or self.use_fsdp:
                    dist.barrier()
                if self.use_fp16:
                    scaler.step(optim)
                    scaler.update()
                else:
                    optim.step()
                scheduler.step()
                # Clear gradients for the next accumulation stage
                optim.zero_grad(set_to_none=True)
                total_time = f"{time.perf_counter() - time5:0.3f}"
                time5 = time.perf_counter()
                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
                speed_stats["total_time"] = total_time
                lr = scheduler.get_last_lr()[0]
                batch_num_epoch = -1
                if hasattr(dataloader_train, "__len__"):
                    batch_num_epoch = len(dataloader_train)
                self.log(epoch, batch_idx,
                         batch_num_epoch=batch_num_epoch,
                         lr=lr,
                         loss=loss.detach().cpu().item(),
                         speed_stats=speed_stats,
                         stats=stats,
                         writer=writer,
                         tag="train",
                         )
            if (batch_idx + 1) % self.validate_interval == 0:
                self.validate_epoch(
                    model=model,
                    dataloader_val=dataloader_val,
                    epoch=epoch,
                    writer=writer
                )
            if (batch_idx+1) % self.save_checkpoint_interval == 0:
                self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
        else:
            if self.use_ddp or self.use_fsdp:
                iterator_stop.fill_(1)
                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
    def validate_epoch(self,
                       model=None,
                       dataloader_val=None,
                       epoch=None,
                       writer=None,
                       **kwargs,
                       ):
        """
        Defines the validation process for a single epoch.
        Should be implemented with the actual model validation steps.
        Args:
            epoch (int): The current epoch number.
        """
        logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
        model.eval()
        with torch.no_grad():
            speed_stats = {}
            time5 = time.perf_counter()
            iterator_stop = torch.tensor(0).to(self.device)
            dist.barrier()
            print(f"before iter, iterator_stop: {iterator_stop}\n")
            for batch_idx, batch in enumerate(dataloader_val):
                if self.use_ddp or self.use_fsdp:
                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
                    if epoch >= 1:
                        print(f"iterator_stop: {iterator_stop}\n")
                    if iterator_stop > 0:
                        break
                time1 = time.perf_counter()
                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
                batch = to_device(batch, self.device)
                time2 = time.perf_counter()
                retval = model(**batch)
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
@@ -165,71 +418,112 @@
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed=True)
                    if self.use_ddp or self.use_fsdp:
                        dist.all_reduce(weight, op=dist.ReduceOp.SUM)
                    # Now weight is summation over all workers
                    loss /= weight
                    loss /= weight.sum() # shape:[1] -> shape:[]
                    # Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
                loss.backward()
                loss = loss
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
            # Perform an optimizer step only after accumulating enough gradients
            if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
                # Perform gradient clipping if it is set
                if self.kwargs.get("grad_clip", None) is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=self.kwargs.get("grad_clip", 10.0),
                        norm_type=self.kwargs.get("grad_clip_type", 2.0),
                    )
                    if not torch.isfinite(grad_norm):
                        logging.warning(
                            f"The grad norm is {grad_norm}. Skipping updating the model."
                        )
                        self.optim.zero_grad()  # Reset gradients
                        continue
                # Execute an optimization step (update model parameters)
                self.optim.step()
                self.scheduler.step()
                # Clear gradients for the next accumulation stage
                self.optim.zero_grad()
                total_time = f"{time.perf_counter() - time5:0.3f}"
                time5 = time.perf_counter()
                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
                speed_stats["total_time"] = total_time
            # import pdb;
            # pdb.set_trace()
            pbar.update(1)
            if self.local_rank == 0:
                description = (
                    f"Epoch: {epoch + 1}/{self.max_epoch}, "
                    f"step {batch_idx}/{len(self.dataloader_train)}, "
                    f"{speed_stats}, "
                    f"(loss: {loss.detach().cpu().item():.3f}), "
                    f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
                )
                pbar.set_description(description)
            # if batch_idx == 2:
            #     break
        pbar.close()
    def _validate_epoch(self, epoch):
        """
        Defines the validation process for a single epoch.
        Should be implemented with the actual model validation steps.
                self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
                if "acc" in stats:
                    self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
                # if self.use_ddp or self.use_fsdp:
                #     val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
                #     val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
                #     dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
                #     dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
                #     self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
                #     self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
                batch_num_epoch = -1
                if hasattr(dataloader_val, "__len__"):
                    batch_num_epoch = len(dataloader_val)
                self.log(epoch, batch_idx,
                         batch_num_epoch=batch_num_epoch,
                         lr=0.0,
                         loss=loss.detach().cpu().item(),
                         speed_stats=speed_stats,
                         stats=stats,
                         writer=writer,
                         tag="val",
                         )
            else:
                if self.use_ddp or self.use_fsdp:
                    iterator_stop.fill_(1)
                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
        self.val_acc_list.append(self.val_acc_avg)
        model.train()
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
    def log(self,
            epoch=0,
            batch_idx=0,
            batch_num_epoch=-1,
            lr=0.0,
            loss=0.0,
            speed_stats=None,
            stats=None,
            writer=None,
            tag="train",
            ):
        if (batch_idx + 1) % self.log_interval == 0:
            gpu_info = "GPU, memory: usage: {:.3f} GB, " \
                       "peak: {:.3f} GB, " \
                       "cache: {:.3f} GB, " \
                       "cache_peak: {:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
                                          torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
                                          torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
                                          torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
                                          )
            loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
            acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
            description = (
                f"{tag}, "
                f"rank: {self.local_rank}, "
                f"epoch: {epoch}/{self.max_epoch}, "
                f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
                f"(loss_avg_rank: {loss:.3f}), "
                f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
                f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}), "
                f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
                f"(lr: {lr:.3e}), "
                f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
                f"{speed_stats}, "
                f"{gpu_info}"
            )
            logging.info(description)
            if writer is not None:
                writer.add_scalar(f'rank{self.local_rank}_loss/{tag}', loss, self.batch_total)
                writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
                writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
                for key, var in stats.items():
                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
                for key, var in speed_stats.items():
                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total)
    def close(self, writer=None):
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        if writer is not None:
            writer.close()
    
        Args:
            epoch (int): The current epoch number.
        """
        self.model.eval()
        with torch.no_grad():
            for data, target in self.dataloader_val:
                # Implement the model validation steps here
                pass
        if self.use_ddp or self.use_fsdp:
            torch.distributed.destroy_process_group()