From d29f201e3201bde6a984e436888a2aae877e449f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 19 Mar 2024 12:04:50 +0800
Subject: [PATCH] trainer: configurable mid-epoch checkpointing and safer resume

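Make mid-epoch checkpointing configurable and harden resume in
funasr/train_utils/trainer.py:

- read save_checkpoint_interval from kwargs (default: every 5000 batches)
  and save rank-0 mid-epoch checkpoints named model.pt.ep{epoch}.{step}
- load resume checkpoints with map_location="cpu" and move the model to
  the training device afterwards
- restore train mode after validation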
---
 funasr/train_utils/trainer.py |   24 +++++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)

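[ Notes for reviewers; `git am` discards everything between the '---'
  scissors line above and the 'diff --git' header below. ]

The resume path now calls torch.load(ckpt, map_location="cpu"), which places
every tensor on the CPU no matter which device wrote the checkpoint; the
trainer then moves the model onto the training device explicitly. A minimal
sketch of the same load-on-CPU-then-move pattern in plain PyTorch (the file
name, model, and device below are illustrative, not taken from funasr):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    torch.save({"state_dict": model.state_dict()}, "model.pt")

    # Map all tensors to CPU on load, then move the model to the target device.
    checkpoint = torch.load("model.pt", map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.to("cuda:0" if torch.cuda.is_available() else "cpu")
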
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 3b20596..14abd6c 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -88,6 +88,7 @@
         scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
         scaler = ShardedGradScaler(enabled=use_fp16) if use_ddp else scaler
         self.scaler = scaler
+        self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)  # batches between mid-epoch saves
         
     
         try:
@@ -104,7 +105,7 @@
         self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None
         
     
-    def _save_checkpoint(self, epoch):
+    def _save_checkpoint(self, epoch, step=None):
         """
         Saves a checkpoint containing the model's state, the optimizer's state,
         and the scheduler's state at the end of the given epoch. This method is
@@ -123,7 +124,11 @@
             state["scaler_state"] = self.scaler.state_dict()
         # Create output directory if it does not exist
         os.makedirs(self.output_dir, exist_ok=True)
-        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
+        if step is None:
+            filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')  # end-of-epoch checkpoint
+        else:
+            filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')  # mid-epoch checkpoint
+
         torch.save(state, filename)
         
         print(f'\nCheckpoint saved to {filename}\n')
@@ -141,7 +146,7 @@
         """
         ckpt = os.path.join(resume_path, "model.pt")
         if os.path.isfile(ckpt):
-            checkpoint = torch.load(ckpt)
+            checkpoint = torch.load(ckpt, map_location="cpu")  # load tensors onto CPU regardless of the device they were saved from
             self.start_epoch = checkpoint['epoch'] + 1
             # self.model.load_state_dict(checkpoint['state_dict'])
             src_state = checkpoint['state_dict']
@@ -163,8 +168,9 @@
                 self.scaler.load_state_dict(checkpoint['scaler_state'])
             print(f"Checkpoint loaded successfully from '{ckpt}'")
         else:
-            print(f"No checkpoint found at '{ckpt}', starting from scratch")
-
+            print(f"No checkpoint found at '{ckpt}'; training state will not be resumed.")
+
+        self.model.to(self.device)  # weights may have been loaded onto CPU; ensure the model is on the training device
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
         
@@ -337,8 +343,10 @@
                     for key, var in speed_stats.items():
                         self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', eval(var), self.batch_total)
 
-
+            if (batch_idx + 1) % self.save_checkpoint_interval == 0 and self.rank == 0:
+                self._save_checkpoint(epoch, step=batch_idx + 1)  # rank-0 mid-epoch checkpoint
         pbar.close()
+        
 
     def _validate_epoch(self, epoch):
         """
@@ -401,4 +409,6 @@
                                                    epoch * len(self.dataloader_val) + batch_idx)
                         for key, var in speed_stats.items():
                             self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', eval(var),
-                                                   epoch * len(self.dataloader_val) + batch_idx)
\ No newline at end of file
+                                                   epoch * len(self.dataloader_val) + batch_idx)
+
+        self.model.train()  # restore train mode after validation
\ No newline at end of file

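Note: with the default save_checkpoint_interval of 5000, the rank-0 check
above fires after batches 5000, 10000, ..., so the first mid-epoch file of
epoch 0 is model.pt.ep0.5000. A stand-alone sketch of the trigger arithmetic
(values illustrative):

    save_checkpoint_interval = 5000
    for batch_idx in range(12000):
        if (batch_idx + 1) % save_checkpoint_interval == 0:
            print(f"would save model.pt.ep0.{batch_idx + 1}")  # 5000, 10000
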
--
Gitblit v1.9.1