From 0e622e694e6cb4459955f1e5942a7c53349ce640 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 19 十二月 2023 21:58:14 +0800
Subject: [PATCH] funasr2

---
 funasr/train_utils/trainer.py |   13 +++++++------
 1 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 425b79f..ea502f7 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -55,7 +55,7 @@
 		self.dataloader_val = dataloader_val
 		self.output_dir = kwargs.get('output_dir', './')
 		self.resume = kwargs.get('resume', True)
-		self.start_epoch = 1
+		self.start_epoch = 0
 		self.max_epoch = kwargs.get('max_epoch', 100)
 		self.local_rank = local_rank
 		self.use_ddp = use_ddp
@@ -123,7 +123,7 @@
 		for epoch in range(self.start_epoch, self.max_epoch + 1):
 			self._train_epoch(epoch)
 			# self._validate_epoch(epoch)
-			if dist.get_rank() == 0:
+			if self.rank == 0:
 				self._save_checkpoint(epoch)
 			self.scheduler.step()
 			break
@@ -201,21 +201,22 @@
 				speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
 	
 				speed_stats["total_time"] = total_time
-				
 			
+			# import pdb;
+			# pdb.set_trace()
 			pbar.update(1)
 			if self.local_rank == 0:
 				description = (
 					f"Epoch: {epoch + 1}/{self.max_epoch}, "
 					f"step {batch_idx}/{len(self.dataloader_train)}, "
 					f"{speed_stats}, "
-					f"(loss: {loss.detach().float():.3f}), "
+					f"(loss: {loss.detach().cpu().item():.3f}), "
 					f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
 				)
 				pbar.set_description(description)
 			
-			if batch_idx == 2:
-				break
+			# if batch_idx == 2:
+			# 	break
 		pbar.close()
 
 	def _validate_epoch(self, epoch):

--
Gitblit v1.9.1