zhifu gao
2024-04-25 80bd14e6bbb7bb282ff3832194648dc4a16157ca
funasr/train_utils/trainer.py
@@ -15,6 +15,7 @@
from funasr.train_utils.average_nbest_models import average_checkpoints
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
@contextmanager
def maybe_autocast(enabled):
    if enabled:
@@ -22,6 +23,7 @@
            yield
    else:
        yield
class Trainer:
    """
@@ -38,14 +40,16 @@
        output_dir (str): Directory where model checkpoints will be saved.
        resume (str, optional): Path to a checkpoint to resume training from.
    """
    def __init__(self,
                 local_rank,
                 use_ddp: bool = False,
                 use_fsdp: bool = False,
                 use_fp16: bool = False,
                 output_dir: str="./",
                 **kwargs):
    def __init__(
        self,
        local_rank,
        use_ddp: bool = False,
        use_fsdp: bool = False,
        use_fp16: bool = False,
        output_dir: str = "./",
        **kwargs,
    ):
        """
        Initializes the Trainer class with the model, optimizer, scheduler, dataloaders, and other settings.
@@ -60,17 +64,17 @@
                      output_dir (str): The directory where model checkpoints will be saved. Default is './'.
                      resume (str, optional): The file path to a checkpoint to resume training from.
        """
        self.output_dir = output_dir
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir, exist_ok=True)
        self.resume = kwargs.get('resume', True)
        self.resume = kwargs.get("resume", True)
        self.start_epoch = 0
        self.max_epoch = kwargs.get('max_epoch', 100)
        self.max_epoch = kwargs.get("max_epoch", 100)
        self.local_rank = local_rank
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = kwargs.get('device', "cuda")
        self.device = kwargs.get("device", "cuda")
        # self.kwargs = kwargs
        self.log_interval = kwargs.get("log_interval", 50)
        self.batch_total = 0
@@ -83,9 +87,7 @@
        self.accum_grad = kwargs.get("accum_grad", 1)
        self.grad_clip = kwargs.get("grad_clip", 10.0)
        self.grad_clip_type = kwargs.get("grad_clip_type", 2.0)
        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
@@ -105,14 +107,19 @@
        self.best_step_or_epoch = ""
        self.val_acc_step_or_eoch = {}
        self.val_loss_step_or_eoch = {}
    def save_checkpoint(self, epoch,
                        step=None,
                        model=None,
                        optim=None,
                        scheduler=None,
                        scaler=None,
                        ):
        self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
    def save_checkpoint(
        self,
        epoch,
        step=None,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
    ):
        """
        Saves a checkpoint containing the model's state, the optimizer's state,
        and the scheduler's state at the end of the given epoch. This method is
@@ -121,15 +128,15 @@
        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        if self.rank == 0:
            logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
            # self.step_or_epoch += 1
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optim.state_dict(),
                'scheduler': scheduler.state_dict(),
                "epoch": epoch,
                "state_dict": model.state_dict(),
                "optimizer": optim.state_dict(),
                "scheduler": scheduler.state_dict(),
                "saved_ckpts": self.saved_ckpts,
                "val_acc_step_or_eoch": self.val_acc_step_or_eoch,
                "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
@@ -138,43 +145,59 @@
            }
            if hasattr(model, "module"):
                state["state_dict"] = model.module.state_dict()
            if scaler:
                state["scaler_state"] = scaler.state_dict()
            # Create output directory if it does not exist
            os.makedirs(self.output_dir, exist_ok=True)
            if step is None:
                ckpt_name = f'model.pt.ep{epoch}'
                ckpt_name = f"model.pt.ep{epoch}"
            else:
                ckpt_name = f'model.pt.ep{epoch}.{step}'
                ckpt_name = f"model.pt.ep{epoch}.{step}"
            filename = os.path.join(self.output_dir, ckpt_name)
            torch.save(state, filename)
            logging.info(f'\nCheckpoint saved to {filename}\n')
            latest = Path(os.path.join(self.output_dir, f'model.pt'))
            logging.info(f"\nCheckpoint saved to {filename}\n")
            latest = Path(os.path.join(self.output_dir, f"model.pt"))
            torch.save(state, latest)
            if self.best_step_or_epoch == "":
                self.best_step_or_epoch = ckpt_name
            if self.avg_keep_nbest_models_type == "acc":
                if self.val_acc_step_or_eoch[ckpt_name] >= self.val_acc_step_or_eoch[self.best_step_or_epoch]:
                if (
                    self.val_acc_step_or_eoch[ckpt_name]
                    >= self.val_acc_step_or_eoch[self.best_step_or_epoch]
                ):
                    self.best_step_or_epoch = ckpt_name
                    best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
                    best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
                    torch.save(state, best_ckpt)
                    logging.info(f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}")
                    logging.info(
                        f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
                    )
                else:
                    logging.info(f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}")
                    logging.info(
                        f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}"
                    )
            elif self.avg_keep_nbest_models_type == "loss":
                if self.val_loss_step_or_eoch[ckpt_name] <= self.val_loss_step_or_eoch[self.best_step_or_epoch]:
                if (
                    self.val_loss_step_or_eoch[ckpt_name]
                    <= self.val_loss_step_or_eoch[self.best_step_or_epoch]
                ):
                    self.best_step_or_epoch = ckpt_name
                    best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best'))
                    best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
                    torch.save(state, best_ckpt)
                    logging.info(f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}")
                    logging.info(
                        f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
                    )
                else:
                    logging.info(f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}")
                    logging.info(
                        f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}"
                    )
            else:
                print("Undo")
            self.saved_ckpts[ckpt_name] = getattr(self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch")[ckpt_name]
            self.saved_ckpts[ckpt_name] = getattr(
                self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
            )[ckpt_name]
            if self.keep_nbest_models > 0:
                if len(self.saved_ckpts) > self.keep_nbest_models:
                    if self.avg_keep_nbest_models_type == "acc":
@@ -187,16 +210,17 @@
                    logging.info(f"Delete: {filename}")
                    if os.path.exists(filename):
                        os.remove(filename)
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
    def resume_checkpoint(self,
                          model=None,
                          optim=None,
                          scheduler=None,
                          scaler=None,
                          ):
    def resume_checkpoint(
        self,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
    ):
        """
        Resumes training from a checkpoint at the given file path.
        Loads the model's state, the optimizer's state, and the scheduler's state.
@@ -208,14 +232,14 @@
            ckpt = os.path.join(self.output_dir, "model.pt")
            if os.path.isfile(ckpt):
                checkpoint = torch.load(ckpt, map_location="cpu")
                self.start_epoch = checkpoint['epoch'] + 1
                self.start_epoch = checkpoint["epoch"] + 1
                # self.model.load_state_dict(checkpoint['state_dict'])
                src_state = checkpoint['state_dict']
                src_state = checkpoint["state_dict"]
                dst_state = model.state_dict()
                for k in dst_state.keys():
                    if not k.startswith("module.") and "module."+k in src_state.keys():
                        k_ddp = "module."+k
                    elif k.startswith("module.") and "module."+k not in src_state.keys():
                    if not k.startswith("module.") and "module." + k in src_state.keys():
                        k_ddp = "module." + k
                    elif k.startswith("module.") and "module." + k not in src_state.keys():
                        k_ddp = k.replace("module.", "", 1)
                    else:
                        k_ddp = k
@@ -223,37 +247,47 @@
                        dst_state[k] = src_state[k_ddp]
                    else:
                        print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}")
                model.load_state_dict(dst_state)
                optim.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                if scaler is not None and 'scaler_state' in checkpoint:
                    scaler.load_state_dict(checkpoint['scaler_state'])
                optim.load_state_dict(checkpoint["optimizer"])
                scheduler.load_state_dict(checkpoint["scheduler"])
                if scaler is not None and "scaler_state" in checkpoint:
                    scaler.load_state_dict(checkpoint["scaler_state"])
                self.saved_ckpts = checkpoint["saved_ckpts"]
                self.val_acc_step_or_eoch = checkpoint["val_acc_step_or_eoch"] if "val_acc_step_or_eoch" in checkpoint else {}
                self.val_loss_step_or_eoch = checkpoint["val_loss_step_or_eoch"] if "val_loss_step_or_eoch" in checkpoint else {}
                self.best_step_or_epoch = checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
                self.val_acc_step_or_eoch = (
                    checkpoint["val_acc_step_or_eoch"]
                    if "val_acc_step_or_eoch" in checkpoint
                    else {}
                )
                self.val_loss_step_or_eoch = (
                    checkpoint["val_loss_step_or_eoch"]
                    if "val_loss_step_or_eoch" in checkpoint
                    else {}
                )
                self.best_step_or_epoch = (
                    checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
                )
                model.to(self.device)
                print(f"Checkpoint loaded successfully from '{ckpt}'")
            else:
                print(f"No checkpoint found at '{ckpt}', does not resume status!")
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
    def train_epoch(self,
                model=None,
                optim=None,
                scheduler=None,
                scaler=None,
                dataloader_train=None,
                dataloader_val=None,
                epoch=None,
                writer=None,
                **kwargs,
                    ):
    def train_epoch(
        self,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
        dataloader_train=None,
        dataloader_val=None,
        epoch=None,
        writer=None,
        **kwargs,
    ):
        """
        Defines the training process for a single epoch with gradient accumulation.
        Args:
@@ -269,7 +303,7 @@
        # Initialize the gradient accumulation
        optim.zero_grad()
        speed_stats = {}
        iterator_stop = torch.tensor(0).to(self.device)
        dataloader_train.batch_sampler.set_epoch(epoch)
@@ -294,6 +328,12 @@
                with maybe_autocast(self.use_fp16):
                    retval = model(**batch)
                    
                    if (
                        self.reset_gpu_cache
                        and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
                    ):
                        torch.cuda.empty_cache()
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss, stats, weight = retval
@@ -306,7 +346,7 @@
                    if self.use_ddp or self.use_fsdp:
                        dist.all_reduce(weight, op=dist.ReduceOp.SUM)
                    # Now weight is summation over all workers
                    loss /= weight.sum() # shape:[1] -> shape:[]
                    loss /= weight.sum()  # shape:[1] -> shape:[]
                    # Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size
@@ -318,19 +358,26 @@
                    loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
                self.train_loss_avg = (self.train_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
                self.train_loss_avg = (
                    self.train_loss_avg * batch_idx + loss.detach().cpu().item()
                ) / (batch_idx + 1)
                if "acc" in stats:
                    self.train_acc_avg = (self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
                    self.train_acc_avg = (
                        self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
                    ) / (batch_idx + 1)
                if self.use_ddp or self.use_fsdp:
                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
                        self.device
                    )
                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
                        self.device
                    )
                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
            # Perform an optimizer step only after accumulating enough gradients
            if (batch_idx + 1) % accum_grad == 0:
                # Perform gradient clipping if it is set
@@ -346,7 +393,7 @@
                        )
                        optim.zero_grad()  # Reset gradients
                        continue
                # Execute an optimization step (update model parameters)
                if self.use_ddp or self.use_fsdp:
                    dist.barrier()
@@ -361,23 +408,25 @@
                total_time = f"{time.perf_counter() - time5:0.3f}"
                time5 = time.perf_counter()
                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
                speed_stats["total_time"] = total_time
                lr = scheduler.get_last_lr()[0]
                batch_num_epoch = 1
                if hasattr(dataloader_train, "__len__"):
                    batch_num_epoch = len(dataloader_train)
                self.log(epoch, batch_idx,
                         batch_num_epoch=batch_num_epoch,
                         lr=lr,
                         loss=loss.detach().cpu().item(),
                         speed_stats=speed_stats,
                         stats=stats,
                         writer=writer,
                         tag="train",
                         data_split_i=kwargs.get("data_split_i", 0),
                         data_split_num=kwargs.get("data_split_num", 1),
                         )
                self.log(
                    epoch,
                    batch_idx,
                    batch_num_epoch=batch_num_epoch,
                    lr=lr,
                    loss=loss.detach().cpu().item(),
                    speed_stats=speed_stats,
                    stats=stats,
                    writer=writer,
                    tag="train",
                    data_split_i=kwargs.get("data_split_i", 0),
                    data_split_num=kwargs.get("data_split_num", 1),
                )
            if (batch_idx + 1) % self.validate_interval == 0:
                self.validate_epoch(
@@ -385,35 +434,41 @@
                    dataloader_val=dataloader_val,
                    epoch=epoch,
                    writer=writer,
                    step=batch_idx+1,
                    step=batch_idx + 1,
                )
            if (batch_idx+1) % self.save_checkpoint_interval == 0:
                self.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler, step=batch_idx+1)
            if (batch_idx + 1) % self.save_checkpoint_interval == 0:
                self.save_checkpoint(
                    epoch,
                    model=model,
                    optim=optim,
                    scheduler=scheduler,
                    scaler=scaler,
                    step=batch_idx + 1,
                )
            time_beg = time.perf_counter()
        else:
            if self.use_ddp or self.use_fsdp:
                iterator_stop.fill_(1)
                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
            iterator_stop = torch.tensor(0).to(self.device)
    def validate_epoch(self,
                       model=None,
                       dataloader_val=None,
                       epoch=None,
                       writer=None,
                       **kwargs,
                       ):
    def validate_epoch(
        self,
        model=None,
        dataloader_val=None,
        epoch=None,
        writer=None,
        **kwargs,
    ):
        """
        Defines the validation process for a single epoch.
        Should be implemented with the actual model validation steps.
        Args:
            epoch (int): The current epoch number.
        """
@@ -421,9 +476,9 @@
            dist.barrier()
        logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
        model.eval()
        with torch.no_grad():
            speed_stats = {}
            time5 = time.perf_counter()
            iterator_stop = torch.tensor(0).to(self.device)
@@ -450,7 +505,7 @@
                    if self.use_ddp or self.use_fsdp:
                        dist.all_reduce(weight, op=dist.ReduceOp.SUM)
                    # Now weight is summation over all workers
                    loss /= weight.sum() # shape:[1] -> shape:[]
                    loss /= weight.sum()  # shape:[1] -> shape:[]
                    # Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size
@@ -458,12 +513,20 @@
                loss = loss
                time4 = time.perf_counter()
                self.val_loss_avg = (self.val_loss_avg*batch_idx + loss.detach().cpu().item())/(batch_idx+1)
                self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
                    batch_idx + 1
                )
                if "acc" in stats:
                    self.val_acc_avg = (self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()) / (batch_idx + 1)
                    self.val_acc_avg = (
                        self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
                    ) / (batch_idx + 1)
                if self.use_ddp or self.use_fsdp:
                    val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
                    val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
                    val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
                        self.device
                    )
                    val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
                        self.device
                    )
                    dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
                    dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
                    self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
@@ -472,15 +535,17 @@
                batch_num_epoch = 1
                if hasattr(dataloader_val, "__len__"):
                    batch_num_epoch = len(dataloader_val)
                self.log(epoch, batch_idx,
                         batch_num_epoch=batch_num_epoch,
                         lr=0.0,
                         loss=loss.detach().cpu().item(),
                         speed_stats=speed_stats,
                         stats=stats,
                         writer=writer,
                         tag="val",
                         )
                self.log(
                    epoch,
                    batch_idx,
                    batch_num_epoch=batch_num_epoch,
                    lr=0.0,
                    loss=loss.detach().cpu().item(),
                    speed_stats=speed_stats,
                    stats=stats,
                    writer=writer,
                    tag="val",
                )
            else:
                if self.use_ddp or self.use_fsdp:
@@ -488,7 +553,7 @@
                    dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
        if kwargs.get("step", None) is None:
            ckpt_name = f'model.pt.ep{epoch}'
            ckpt_name = f"model.pt.ep{epoch}"
        else:
            ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step")}'
        self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
@@ -498,34 +563,37 @@
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
            iterator_stop = torch.tensor(0).to(self.device)
    def log(self,
            epoch=0,
            batch_idx=0,
            batch_num_epoch=-1,
            lr=0.0,
            loss=0.0,
            speed_stats=None,
            stats=None,
            writer=None,
            tag="train",
            data_split_i=0,
            data_split_num=1,
            **kwargs,
            ):
    def log(
        self,
        epoch=0,
        batch_idx=0,
        batch_num_epoch=-1,
        lr=0.0,
        loss=0.0,
        speed_stats=None,
        stats=None,
        writer=None,
        tag="train",
        data_split_i=0,
        data_split_num=1,
        **kwargs,
    ):
        if (batch_idx + 1) % self.log_interval == 0:
            gpu_info = "GPU, memory: usage: {:.3f} GB, " \
                       "peak: {:.3f} GB, " \
                       "cache: {:.3f} GB, " \
                       "cache_peak: {:.3f} GB".format(torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
                                          torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
                                          torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
                                          torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
                                          )
            gpu_info = (
                "GPU, memory: usage: {:.3f} GB, "
                "peak: {:.3f} GB, "
                "cache: {:.3f} GB, "
                "cache_peak: {:.3f} GB".format(
                    torch.cuda.memory_allocated() / 1024 / 1024 / 1024,
                    torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024,
                    torch.cuda.memory_reserved() / 1024 / 1024 / 1024,
                    torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024,
                )
            )
            loss_avg_epoch = getattr(self, f"{tag}_loss_avg")
            acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
            description = (
@@ -544,23 +612,27 @@
                f"{gpu_info}"
            )
            logging.info(description)
            if writer is not None:
                writer.add_scalar(f'rank{self.local_rank}_loss/{tag}', loss, self.batch_total)
                writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
                writer.add_scalar(f'rank{self.local_rank}_lr/{tag}', lr, self.batch_total)
                writer.add_scalar(f"rank{self.local_rank}_loss/{tag}", loss, self.batch_total)
                writer.add_scalar(f"rank{self.local_rank}_lr/{tag}", lr, self.batch_total)
                writer.add_scalar(f"rank{self.local_rank}_lr/{tag}", lr, self.batch_total)
                for key, var in stats.items():
                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', var.item(), self.batch_total)
                    writer.add_scalar(
                        f"stats_rank{self.local_rank}_{key}/{tag}", var.item(), self.batch_total
                    )
                for key, var in speed_stats.items():
                    writer.add_scalar(f'stats_rank{self.local_rank}_{key}/{tag}', eval(var), self.batch_total)
                    writer.add_scalar(
                        f"stats_rank{self.local_rank}_{key}/{tag}", eval(var), self.batch_total
                    )
    def close(self, writer=None):
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        if writer is not None:
            writer.close()
        if self.use_ddp or self.use_fsdp:
            torch.distributed.destroy_process_group()
            torch.distributed.destroy_process_group()