游雁
2023-12-06 1c2eb051cdcc6890af9ba64b10b9a0152288469a
funasr/cli/trainer.py
@@ -4,6 +4,7 @@
 import logging
 from tqdm import tqdm
 from contextlib import nullcontext
+import torch.distributed as dist
 
 class Trainer:
    """
@@ -80,7 +81,7 @@
       }
       # Create output directory if it does not exist
       os.makedirs(self.output_dir, exist_ok=True)
-      filename = os.path.join(self.output_dir, f'model.{epoch}.pb')
+      filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
       torch.save(state, filename)
       print(f'Checkpoint saved to {filename}')
 
@@ -110,8 +111,9 @@
       for epoch in range(self.start_epoch, self.max_epoch + 1):
          self._train_epoch(epoch)
          # self._validate_epoch(epoch)
-         self._save_checkpoint(epoch)
-         self.scheduler.step()
+         if dist.get_rank() == 0:
+            self._save_checkpoint(epoch)
+         # self.scheduler.step()
 
    def _train_epoch(self, epoch):
       """
@@ -131,7 +133,7 @@
       for batch_idx, batch in enumerate(self.dataloader_train):
          batch = to_device(batch, self.device)
 
-         my_context = model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
+         my_context = self.model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
          with my_context():
             retval = self.model(**batch)
             loss, stats, weight = retval
@@ -163,9 +165,10 @@
             self.optim.zero_grad()
 
          pbar.update(1)
-         pbar.set_description(
-            f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float()})")
+         if self.local_rank == 0:
+            pbar.set_description(
+               f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float()})")
       pbar.close()
 
    # def _train_epoch(self, epoch):
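For reference, the change combines two standard DistributedDataParallel idioms: suppressing the gradient all-reduce on accumulation steps via no_sync, and writing checkpoints from rank 0 only so concurrent processes do not race on the same file. Below is a minimal sketch of both patterns, assuming a DDP-wrapped model and an initialized process group; the helper names train_epoch and save_checkpoint are hypothetical, not part of trainer.py.

import torch
import torch.distributed as dist
from contextlib import nullcontext

def train_epoch(model, dataloader, optim, accumulation_steps):
    # `model` is assumed to be wrapped in torch.nn.parallel.DistributedDataParallel.
    for batch_idx, batch in enumerate(dataloader):
        # no_sync suppresses DDP's gradient all-reduce; only the iteration
        # that actually calls optim.step() needs synchronized gradients.
        my_context = model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
        with my_context():
            loss, stats, weight = model(**batch)
            (loss / accumulation_steps).backward()
        if batch_idx % accumulation_steps == 0:
            optim.step()
            optim.zero_grad()

def save_checkpoint(state, filename):
    # After the all-reduce every rank holds identical weights, so a single
    # writer (rank 0) suffices and avoids N ranks clobbering the same file.
    if dist.get_rank() == 0:
        torch.save(state, filename)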