From d77910eb6d171727f2350e45c31c91436c4c8891 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 11 十二月 2023 13:42:40 +0800
Subject: [PATCH] funasr2

---
 funasr/cli/trainer.py |   48 +++++++++++++++++++++++++++++++++++++++++-------
 1 files changed, 41 insertions(+), 7 deletions(-)

diff --git a/funasr/cli/trainer.py b/funasr/cli/trainer.py
index 28a843b..0178767 100644
--- a/funasr/cli/trainer.py
+++ b/funasr/cli/trainer.py
@@ -2,6 +2,7 @@
 import os
 from funasr.torch_utils.device_funcs import to_device
 import logging
+import time
 from tqdm import tqdm
 from contextlib import nullcontext
 import torch.distributed as dist
@@ -53,19 +54,27 @@
 		self.dataloader_train = dataloader_train
 		self.dataloader_val = dataloader_val
 		self.output_dir = kwargs.get('output_dir', './')
-		self.resume = kwargs.get('resume', None)
+		self.resume = kwargs.get('resume', True)
 		self.start_epoch = 1
 		self.max_epoch = kwargs.get('max_epoch', 100)
 		self.local_rank = local_rank
-		self.rank = dist.get_rank()
-		self.world_size = dist.get_world_size()
 		self.use_ddp = use_ddp
 		self.use_fsdp = use_fsdp
-		self.device = torch.device("cuda", local_rank)
+		self.device = next(model.parameters()).device
 		self.kwargs = kwargs
 		
 		if self.resume:
 			self._resume_checkpoint(self.resume)
+	
+		try:
+			rank = dist.get_rank()
+			world_size = dist.get_world_size()
+		except:
+			rank = 0
+			world_size = 1
+			logging.warning("distributed is not initialized, only single shard")
+		self.rank = rank
+		self.world_size = world_size
 	
 	def _save_checkpoint(self, epoch):
 		"""
@@ -117,6 +126,7 @@
 			if dist.get_rank() == 0:
 				self._save_checkpoint(epoch)
 			self.scheduler.step()
+			break
 	
 	def _train_epoch(self, epoch):
 		"""
@@ -132,13 +142,21 @@
 		accum_grad = self.kwargs.get("accum_grad", 1)
 		# Initialize the gradient accumulation
 		self.optim.zero_grad()
-		
+		speed_stats = {}
+		time5 = time.perf_counter()
 		for batch_idx, batch in enumerate(self.dataloader_train):
+			time1 = time.perf_counter()
+			speed_stats["data_load"] = f"{time1-time5:0.3f}"
+			# import pdb;
+			# pdb.set_trace()
 			batch = to_device(batch, self.device)
 			
 			my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
 			with my_context():
+				time2 = time.perf_counter()
 				retval = self.model(**batch)
+				time3 = time.perf_counter()
+				speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
 				loss, stats, weight = retval
 				stats = {k: v for k, v in stats.items() if v is not None}
 				if self.use_ddp or self.use_fsdp:
@@ -154,6 +172,8 @@
 				# Scale the loss since we're not updating for every mini-batch
 				loss = loss / accum_grad
 				loss.backward()
+				time4 = time.perf_counter()
+				speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
 			
 			# Perform an optimizer step only after accumulating enough gradients
 			if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
@@ -176,12 +196,26 @@
 				self.scheduler.step()
 				# Clear gradients for the next accumulation stage
 				self.optim.zero_grad()
+				total_time = f"{time.perf_counter() - time5:0.3f}"
+				time5 = time.perf_counter()
+				speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
+	
+				speed_stats["total_time"] = total_time
+				
 			
 			pbar.update(1)
 			if self.local_rank == 0:
-				pbar.set_description(
-					f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float():.3f}, {[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]})")
+				description = (
+					f"Epoch: {epoch + 1}/{self.max_epoch}, "
+					f"step {batch_idx}/{len(self.dataloader_train)}, "
+					f"{speed_stats}, "
+					f"(loss: {loss.detach().float():.3f}), "
+					f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+				)
+				pbar.set_description(description)
 			
+			if batch_idx == 2:
+				break
 		pbar.close()
 
 	def _validate_epoch(self, epoch):

--
Gitblit v1.9.1