游雁
2023-12-06 15868f623089cf70983a8b4f435ff86e7f160b8a
funasr/cli/trainer.py
@@ -5,6 +5,7 @@
from tqdm import tqdm
from contextlib import nullcontext
import torch.distributed as dist
from funasr.torch_utils.recursive_op import recursive_average
class Trainer:
   """
@@ -56,6 +57,8 @@
      self.start_epoch = 1
      self.max_epoch = kwargs.get('max_epoch', 100)
      self.local_rank = local_rank
      self.rank = dist.get_rank()
      self.world_size = dist.get_world_size()
      self.use_ddp = use_ddp
      self.use_fsdp = use_fsdp
      self.device = torch.device("cuda", local_rank)
@@ -113,7 +116,7 @@
         # self._validate_epoch(epoch)
         if dist.get_rank() == 0:
            self._save_checkpoint(epoch)
         # self.scheduler.step()
         self.scheduler.step()
   
   def _train_epoch(self, epoch):
      """
@@ -126,24 +129,34 @@
                  dynamic_ncols=True)
      
      # Set the number of steps for gradient accumulation
      accumulation_steps = self.kwargs.get("accumulation_steps", 1)
      accum_grad = self.kwargs.get("accum_grad", 1)
      # Initialize the gradient accumulation
      self.optim.zero_grad()
      
      for batch_idx, batch in enumerate(self.dataloader_train):
         batch = to_device(batch, self.device)
         
         my_context = self.model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
         my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
         with my_context():
            retval = self.model(**batch)
            loss, stats, weight = retval
            stats = {k: v for k, v in stats.items() if v is not None}
            if self.use_ddp or self.use_fsdp:
               # Apply weighted averaging for loss and stats
               loss = (loss * weight.type(loss.dtype)).sum()
               # if distributed, this method can also apply all_reduce()
               stats, weight = recursive_average(stats, weight, distributed=True)
               # Now weight is summation over all workers
               loss /= weight
               # Multiply world_size because DistributedDataParallel
               # automatically normalizes the gradient by world_size.
               loss *= self.world_size
            # Scale the loss since we're not updating for every mini-batch
            loss = loss / accumulation_steps
            loss = loss / accum_grad
            loss.backward()
         
         # Perform an optimizer step only after accumulating enough gradients
         if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(self.dataloader_train):
         if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
            # Perform gradient clipping if it is set
            if self.kwargs.get("grad_clip", None) is not None:
               grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -170,43 +183,6 @@
               f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float()})")
         
      pbar.close()
   # def _train_epoch(self, epoch):
   #    """
   #    Defines the training process for a single epoch.
   #    Should be implemented with the actual model training steps.
   #
   #    Args:
   #       epoch (int): The current epoch number.
   #    """
   #    self.model.train()
   #    pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train), dynamic_ncols=True)
   #    for batch_idx, batch in enumerate(self.dataloader_train):
   #       batch = to_device(batch, "cpu")
   #       retval = self.model(**batch)
   #       loss, stats, weight = retval
   #       self.optim.zero_grad()
   #       loss.backward()
   #
   #       # compute the gradient norm to check if it is normal or not
   #       grad_norm = torch.nn.utils.clip_grad_norm_(
   #          self.model.parameters(),
   #          max_norm=self.kwargs.get("grad_clip", 10.0),
   #          norm_type=self.kwargs.get("grad_clip_type", 2.0),
   #       )
   #       if not torch.isfinite(grad_norm):
   #          logging.warning(
   #             f"The grad norm is {grad_norm}. Skipping updating the model."
   #          )
   #          continue
   #       self.optim.step()
   #       self.scheduler.step()
   #       pbar.update(1)
   #       pbar.set_description(
   #          f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float()})")
   #
   #    pbar.close()
   #
   def _validate_epoch(self, epoch):
      """
@@ -221,19 +197,3 @@
         for data, target in self.dataloader_val:
            # Implement the model validation steps here
            pass
# # Example usage
# if __name__ == "__main__":
#    # Assuming the following objects have already been correctly created and initialized:
#    # model, optim, scheduler, dataloader_train, and dataloader_val.
#    trainer = Trainer(
#        max_epoch=10,
#        model=model,
#        optim=optim,
#        scheduler=scheduler,
#        dataloader_train=dataloader_train,
#        dataloader_val=dataloader_val,
#        output_dir='path_to_save_model',
#        resume='path_to_checkpoint_if_any'
#    )
#    trainer.run()