From c0008fd46134d60a3a41b022bf9156cea5b145e5 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 11 十二月 2023 10:10:40 +0800
Subject: [PATCH] Merge branch 'dev_gzf_funasr2' into main
---
funasr/cli/trainer.py | 199 +++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 199 insertions(+), 0 deletions(-)
diff --git a/funasr/cli/trainer.py b/funasr/cli/trainer.py
new file mode 100644
index 0000000..28a843b
--- /dev/null
+++ b/funasr/cli/trainer.py
@@ -0,0 +1,199 @@
+import torch
+import os
+from funasr.torch_utils.device_funcs import to_device
+import logging
+from tqdm import tqdm
+from contextlib import nullcontext
+import torch.distributed as dist
+from funasr.torch_utils.recursive_op import recursive_average
+
class Trainer:
    """
    A simple trainer for a PyTorch model with gradient accumulation.

    It saves a checkpoint at the end of each epoch and can optionally resume
    from a previously saved checkpoint.

    Attributes:
        model (torch.nn.Module): The model to be trained.
        optim (torch.optim.Optimizer): The optimizer to use for training.
        scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
        dataloader_train (torch.utils.data.DataLoader): DataLoader for the training dataset.
        dataloader_val (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        output_dir (str): Directory where model checkpoints will be saved.
        resume (str, optional): Path to a checkpoint to resume training from.
        start_epoch (int): First epoch to run (1-based); advanced when resuming.
        max_epoch (int): Last epoch to run (inclusive).
        local_rank (int): Index of the CUDA device used by this process.
        rank (int): Global rank of this process (0 when not distributed).
        world_size (int): Number of processes (1 when not distributed).
        use_ddp (bool): Whether distributed data parallel training is enabled.
        use_fsdp (bool): Whether fully sharded data parallel training is enabled.
        device (torch.device): Device that batches are moved to.
        kwargs (dict): Remaining options (``accum_grad``, ``grad_clip``, ``grad_clip_type``, ...).
    """

    def __init__(self, model,
                 optim,
                 scheduler,
                 dataloader_train,
                 dataloader_val,
                 local_rank,
                 use_ddp=False,
                 use_fsdp=False,
                 **kwargs):
        """
        Initializes the Trainer with the model, optimizer, scheduler, dataloaders, and other settings.

        Args:
            model (torch.nn.Module): The model to be trained.
            optim (torch.optim.Optimizer): The optimizer to use for training.
            scheduler (torch.optim.lr_scheduler._LRScheduler): The learning rate scheduler.
            dataloader_train (torch.utils.data.DataLoader): The DataLoader for the training dataset.
            dataloader_val (torch.utils.data.DataLoader): The DataLoader for the validation dataset.
            local_rank (int): Index of the CUDA device used by this process.
            use_ddp (bool): Whether distributed data parallel training is enabled.
            use_fsdp (bool): Whether fully sharded data parallel training is enabled.
            **kwargs: Additional keyword arguments:
                max_epoch (int): The maximum number of epochs for training. Default is 100.
                output_dir (str): The directory where model checkpoints will be saved. Default is './'.
                resume (str, optional): The file path to a checkpoint to resume training from.
                accum_grad (int): Number of batches to accumulate before an optimizer step. Default is 1.
                grad_clip (float, optional): Maximum gradient norm; clipping is disabled when absent.
                grad_clip_type (float): Norm type used for gradient clipping. Default is 2.0.
        """

        self.model = model
        self.optim = optim
        self.scheduler = scheduler
        self.dataloader_train = dataloader_train
        self.dataloader_val = dataloader_val
        self.output_dir = kwargs.get('output_dir', './')
        self.resume = kwargs.get('resume', None)
        self.start_epoch = 1
        self.max_epoch = kwargs.get('max_epoch', 100)
        self.local_rank = local_rank
        # Querying the rank without an initialised process group raises, so
        # fall back to a single-process view for non-distributed runs.
        if dist.is_available() and dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.device = torch.device("cuda", local_rank)
        self.kwargs = kwargs

        if self.resume:
            self._resume_checkpoint(self.resume)

    def _save_checkpoint(self, epoch):
        """
        Saves a checkpoint containing the epoch number and the states of the
        model, the optimizer, and the scheduler.

        The file is written to ``<output_dir>/model.e<epoch>.pb``; the output
        directory is created if it does not exist.

        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optim.state_dict(),
            'scheduler': self.scheduler.state_dict(),
        }
        # Create output directory if it does not exist
        os.makedirs(self.output_dir, exist_ok=True)
        filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
        torch.save(state, filename)
        print(f'Checkpoint saved to {filename}')

    def _resume_checkpoint(self, resume_path):
        """
        Resumes training from a checkpoint at the given file path.

        Loads the model's state, the optimizer's state, and the scheduler's
        state, and sets ``start_epoch`` to the epoch after the saved one.
        If the file does not exist, a message is printed and training starts
        from scratch.

        Args:
            resume_path (str): The file path to the checkpoint to resume from.
        """
        if os.path.isfile(resume_path):
            # Deserialize onto the CPU so that every rank does not restore the
            # tensors onto the device they were saved from; load_state_dict
            # then copies the values into the existing parameters.
            checkpoint = torch.load(resume_path, map_location="cpu")
            self.start_epoch = checkpoint['epoch'] + 1
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
        else:
            print(f"No checkpoint found at '{resume_path}', starting from scratch")

    def run(self):
        """
        Starts the training process, iterating over epochs, training the model,
        and saving a checkpoint (on rank 0 only) at the end of each epoch.
        """
        for epoch in range(self.start_epoch, self.max_epoch + 1):
            self._train_epoch(epoch)
            # self._validate_epoch(epoch)
            if self.rank == 0:
                self._save_checkpoint(epoch)
            # NOTE(review): the scheduler is also stepped after every optimizer
            # update inside _train_epoch; confirm that stepping it once more
            # per epoch is intended for the scheduler in use.
            self.scheduler.step()

    def _train_epoch(self, epoch):
        """
        Runs the training process for a single epoch with gradient accumulation.

        The model is expected to return a ``(loss, stats, weight)`` tuple when
        called with the batch as keyword arguments.

        Args:
            epoch (int): The current epoch number (1-based).
        """
        self.model.train()
        num_batches = len(self.dataloader_train)
        pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch}", total=num_batches,
                    dynamic_ncols=True)

        # Set the number of steps for gradient accumulation
        accum_grad = self.kwargs.get("accum_grad", 1)
        grad_clip = self.kwargs.get("grad_clip", None)
        grad_clip_type = self.kwargs.get("grad_clip_type", 2.0)
        # Only wrapped models (e.g. DistributedDataParallel) expose no_sync()
        can_skip_sync = hasattr(self.model, "no_sync")
        # Initialize the gradient accumulation
        self.optim.zero_grad()

        for batch_idx, batch in enumerate(self.dataloader_train):
            batch = to_device(batch, self.device)

            # The optimizer steps on the last batch of each accumulation window
            # and on the final batch of the epoch.
            is_update_step = (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == num_batches
            # Gradients must be synchronized on the backward pass that precedes
            # the optimizer step; all other backward passes may skip the sync.
            if is_update_step or not can_skip_sync:
                sync_context = nullcontext
            else:
                sync_context = self.model.no_sync

            with sync_context():
                retval = self.model(**batch)
                loss, stats, weight = retval
                stats = {k: v for k, v in stats.items() if v is not None}
                if self.use_ddp or self.use_fsdp:
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # if distributed, this method can also apply all_reduce()
                    stats, weight = recursive_average(stats, weight, distributed=True)
                    # Now weight is summation over all workers
                    loss /= weight
                    # Multiply world_size because DistributedDataParallel
                    # automatically normalizes the gradient by world_size.
                    loss *= self.world_size
                # Scale the loss since we're not updating for every mini-batch
                loss = loss / accum_grad
                loss.backward()

            # Perform an optimizer step only after accumulating enough gradients
            if is_update_step:
                skip_update = False
                # Perform gradient clipping if it is set
                if grad_clip is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=grad_clip,
                        norm_type=grad_clip_type,
                    )
                    if not torch.isfinite(grad_norm):
                        logging.warning(
                            f"The grad norm is {grad_norm}. Skipping updating the model."
                        )
                        skip_update = True

                if not skip_update:
                    # Execute an optimization step (update model parameters)
                    self.optim.step()
                    self.scheduler.step()
                # Clear gradients for the next accumulation stage (this also
                # discards the non-finite gradients of a skipped update)
                self.optim.zero_grad()

            # Always advance the bar, including when an update was skipped
            pbar.update(1)
            if self.local_rank == 0:
                stats_view = [(k, round(v.cpu().item(), 3)) for k, v in stats.items()]
                pbar.set_description(
                    f"Training Epoch: {epoch}/{self.max_epoch}, step {batch_idx + 1}/{num_batches} "
                    f"(loss: {loss.detach().float():.3f}, {stats_view})")

        pbar.close()

    def _validate_epoch(self, epoch):
        """
        Defines the validation process for a single epoch.

        This is a placeholder: it switches the model to evaluation mode and
        iterates over the validation DataLoader without computing anything.
        It should be implemented with the actual model validation steps.

        Args:
            epoch (int): The current epoch number.
        """
        self.model.eval()
        with torch.no_grad():
            for data, target in self.dataloader_val:
                # Implement the model validation steps here
                pass
--
Gitblit v1.9.1