From 1c2eb051cdcc6890af9ba64b10b9a0152288469a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 06 十二月 2023 19:45:49 +0800
Subject: [PATCH] funasr2

---
 funasr/cli/trainer.py |   17 ++++++++++-------
 1 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/funasr/cli/trainer.py b/funasr/cli/trainer.py
index 30e0419..ee5af0f 100644
--- a/funasr/cli/trainer.py
+++ b/funasr/cli/trainer.py
@@ -4,6 +4,7 @@
 import logging
 from tqdm import tqdm
 from contextlib import nullcontext
+import torch.distributed as dist
 
 class Trainer:
 	"""
@@ -80,7 +81,7 @@
 		}
 		# Create output directory if it does not exist
 		os.makedirs(self.output_dir, exist_ok=True)
-		filename = os.path.join(self.output_dir, f'model.{epoch}.pb')
+		filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
 		torch.save(state, filename)
 		print(f'Checkpoint saved to {filename}')
 	
@@ -110,8 +111,9 @@
 		for epoch in range(self.start_epoch, self.max_epoch + 1):
 			self._train_epoch(epoch)
 			# self._validate_epoch(epoch)
-			self._save_checkpoint(epoch)
-			self.scheduler.step()
+			if dist.get_rank() == 0:
+				self._save_checkpoint(epoch)
+			# self.scheduler.step()
 	
 	def _train_epoch(self, epoch):
 		"""
@@ -131,7 +133,7 @@
 		for batch_idx, batch in enumerate(self.dataloader_train):
 			batch = to_device(batch, self.device)
 			
-			my_context = model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
+			my_context = self.model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
 			with my_context():
 				retval = self.model(**batch)
 				loss, stats, weight = retval
@@ -163,9 +165,10 @@
 				self.optim.zero_grad()
 			
 			pbar.update(1)
-			pbar.set_description(
-				f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float()})")
-		
+			if self.local_rank == 0:
+				pbar.set_description(
+					f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float()})")
+			
 		pbar.close()
 	
 	# def _train_epoch(self, epoch):

--
Gitblit v1.9.1