游雁
2024-03-18 cbe2ea7e07cbf364827bd89cefc42b3f643ea3be
funasr/train_utils/trainer.py
@@ -5,7 +5,8 @@
from tqdm import tqdm
from datetime import datetime
import torch.distributed as dist
from contextlib import nullcontext
from torch.cuda.amp import autocast, GradScaler
from contextlib import nullcontext, contextmanager
# from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
from pathlib import Path
@@ -13,6 +14,15 @@
from funasr.train_utils.device_funcs import to_device
from funasr.train_utils.recursive_op import recursive_average
from funasr.train_utils.average_nbest_models import average_checkpoints
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
@contextmanager
def maybe_autocast(enabled):
    if enabled:
        with autocast():
            yield
    else:
        yield
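The maybe_autocast helper added above makes mixed precision optional at the call site: with enabled=True it enters torch.cuda.amp.autocast, otherwise it degenerates to a plain yield. A minimal usage sketch, assuming a model and batch are in scope (the function and argument names below are placeholders, not part of the commit):

def forward_with_optional_amp(model, batch, use_fp16):
    # The same call site serves both precisions; autocast is entered only
    # when use_fp16 is True, mirroring the trainer's forward pass further down.
    with maybe_autocast(use_fp16):
        return model(**batch)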
class Trainer:
    """
@@ -36,8 +46,9 @@
                 dataloader_train,
                 dataloader_val,
                 local_rank,
                 use_ddp=False,
                 use_fsdp=False,
                 use_ddp: bool = False,
                 use_fsdp: bool = False,
                 use_fp16: bool = False,
                 output_dir: str="./",
                 **kwargs):
        """
@@ -72,6 +83,12 @@
        self.kwargs = kwargs
        self.log_interval = kwargs.get("log_interval", 50)
        self.batch_total = 0
        self.use_fp16 = use_fp16
        self.disable_gpu_cache = kwargs.get("disable_gpu_cache", True)
        scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
        scaler = ShardedGradScaler(enabled=use_fp16) if use_fsdp else scaler
        self.scaler = scaler
        self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
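GradScaler performs dynamic loss scaling for plain fp16 training, while ShardedGradScaler (imported from torch.distributed.fsdp above) is its FSDP-aware counterpart that coordinates the inf/NaN check across parameter shards. A rough sketch of the selection logic under that reading; the helper name is illustrative only:

def make_grad_scaler(use_fp16: bool, use_fsdp: bool):
    # No scaler at all for fp32 training; pick the FSDP-aware variant only
    # when parameters are sharded.
    if not use_fp16:
        return None
    return ShardedGradScaler(enabled=True) if use_fsdp else GradScaler(enabled=True)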
        
    
        try:
@@ -88,7 +105,7 @@
        self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None
        
    
    def _save_checkpoint(self, epoch):
    def _save_checkpoint(self, epoch, step=None):
        """
        Saves a checkpoint containing the model's state, the optimizer's state,
        and the scheduler's state at the end of the given epoch. This method is
@@ -103,9 +120,15 @@
            'optimizer': self.optim.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        if self.scaler:
            state["scaler_state"] = self.scaler.state_dict()
        # Create output directory if it does not exist
        os.makedirs(self.output_dir, exist_ok=True)
        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
        if step is None:
            filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
        else:
            filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')
        torch.save(state, filename)
        
        print(f'\nCheckpoint saved to {filename}\n')
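Given the naming scheme above (model.pt.ep{epoch} for end-of-epoch checkpoints, model.pt.ep{epoch}.{step} for intra-epoch ones), a hypothetical helper, not part of this commit, could locate the newest file when resuming by hand:

import glob
import os

def find_latest_checkpoint(output_dir):
    # Returns the most recently written checkpoint matching either pattern,
    # or None if the directory holds no checkpoints yet.
    candidates = glob.glob(os.path.join(output_dir, "model.pt.ep*"))
    return max(candidates, key=os.path.getmtime) if candidates else None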
@@ -123,7 +146,7 @@
        """
        ckpt = os.path.join(resume_path, "model.pt")
        if os.path.isfile(ckpt):
            checkpoint = torch.load(ckpt)
            checkpoint = torch.load(ckpt, map_location="cpu")
            self.start_epoch = checkpoint['epoch'] + 1
            # self.model.load_state_dict(checkpoint['state_dict'])
            src_state = checkpoint['state_dict']
@@ -141,10 +164,13 @@
            self.model.load_state_dict(dst_state)
            self.optim.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            if self.scaler and 'scaler_state' in checkpoint:
                self.scaler.load_state_dict(checkpoint['scaler_state'])
            print(f"Checkpoint loaded successfully from '{ckpt}'")
        else:
            print(f"No checkpoint found at '{ckpt}', starting from scratch")
            print(f"No checkpoint found at '{ckpt}', does not resume status!")
        self.model.to(self.device)
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        
@@ -221,9 +247,10 @@
            my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
            with my_context():
                time2 = time.perf_counter()
                retval = self.model(**batch)
                torch.cuda.empty_cache()
                with maybe_autocast(self.use_fp16):
                    retval = self.model(**batch)
                if self.disable_gpu_cache:
                    torch.cuda.empty_cache()
                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
@@ -241,7 +268,10 @@
                    loss *= self.world_size
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
                loss.backward()
                if self.use_fp16:
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()
                time4 = time.perf_counter()
                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
            
@@ -264,10 +294,14 @@
                # Execute an optimization step (update model parameters)
                if self.use_ddp or self.use_fsdp:
                    dist.barrier()
                self.optim.step()
                if self.use_fp16:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()
                self.scheduler.step()
                # Clear gradients for the next accumulation stage
                self.optim.zero_grad()
                self.optim.zero_grad(set_to_none=True)
                total_time = f"{time.perf_counter() - time5:0.3f}"
                time5 = time.perf_counter()
                speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
@@ -309,8 +343,10 @@
                    for key, var in speed_stats.items():
                        self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', eval(var), self.batch_total)
            if (batch_idx+1) % self.save_checkpoint_interval == 0 and self.rank == 0:
                self._save_checkpoint(epoch, step=batch_idx+1)
        pbar.close()
    def _validate_epoch(self, epoch):
        """
@@ -373,4 +409,6 @@
                                                   epoch * len(self.dataloader_val) + batch_idx)
                        for key, var in speed_stats.items():
                            self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', eval(var),
                                                   epoch * len(self.dataloader_val) + batch_idx)
        self.model.train()