From 806a03609df033d61f824f1ab8527eb88fe837ad Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 13 Dec 2023 19:43:13 +0800
Subject: [PATCH] funasr2 paraformer bicifparaformer contextualparaformer

---
 funasr/cli/trainer.py |  129 +++++++++++++++++++++----------------------
 1 file changed, 63 insertions(+), 66 deletions(-)

diff --git a/funasr/cli/trainer.py b/funasr/cli/trainer.py
index 30e0419..0178767 100644
--- a/funasr/cli/trainer.py
+++ b/funasr/cli/trainer.py
@@ -2,8 +2,11 @@
 import os
 from funasr.torch_utils.device_funcs import to_device
 import logging
+import time
 from tqdm import tqdm
 from contextlib import nullcontext
+import torch.distributed as dist
+from funasr.torch_utils.recursive_op import recursive_average
 
 class Trainer:
 	"""
@@ -51,17 +54,27 @@
 		self.dataloader_train = dataloader_train
 		self.dataloader_val = dataloader_val
 		self.output_dir = kwargs.get('output_dir', './')
-		self.resume = kwargs.get('resume', None)
+		self.resume = kwargs.get('resume', True)
 		self.start_epoch = 1
 		self.max_epoch = kwargs.get('max_epoch', 100)
 		self.local_rank = local_rank
 		self.use_ddp = use_ddp
 		self.use_fsdp = use_fsdp
-		self.device = torch.device("cuda", local_rank)
+		self.device = next(model.parameters()).device
 		self.kwargs = kwargs
 		
 		if self.resume:
 			self._resume_checkpoint(self.resume)
+	
+		try:
+			rank = dist.get_rank()
+			world_size = dist.get_world_size()
+		except:
+			rank = 0
+			world_size = 1
+			logging.warning("distributed is not initialized, only single shard")
+		self.rank = rank
+		self.world_size = world_size
 	
 	def _save_checkpoint(self, epoch):
 		"""
@@ -80,7 +93,7 @@
 		}
 		# Create output directory if it does not exist
 		os.makedirs(self.output_dir, exist_ok=True)
-		filename = os.path.join(self.output_dir, f'model.{epoch}.pb')
+		filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
 		torch.save(state, filename)
 		print(f'Checkpoint saved to {filename}')
 	
@@ -110,8 +123,10 @@
 		for epoch in range(self.start_epoch, self.max_epoch + 1):
 			self._train_epoch(epoch)
 			# self._validate_epoch(epoch)
-			self._save_checkpoint(epoch)
+			if dist.get_rank() == 0:
+				self._save_checkpoint(epoch)
 			self.scheduler.step()
+			break
 	
 	def _train_epoch(self, epoch):
 		"""
@@ -124,24 +139,44 @@
 		            dynamic_ncols=True)
 		
 		# Set the number of steps for gradient accumulation
-		accumulation_steps = self.kwargs.get("accumulation_steps", 1)
+		accum_grad = self.kwargs.get("accum_grad", 1)
 		# Initialize the gradient accumulation
 		self.optim.zero_grad()
-		
+		speed_stats = {}
+		time5 = time.perf_counter()
 		for batch_idx, batch in enumerate(self.dataloader_train):
+			time1 = time.perf_counter()
+			speed_stats["data_load"] = f"{time1-time5:0.3f}"
+			# import pdb;
+			# pdb.set_trace()
 			batch = to_device(batch, self.device)
 			
-			my_context = model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
+			my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
 			with my_context():
+				time2 = time.perf_counter()
 				retval = self.model(**batch)
+				time3 = time.perf_counter()
+				speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
 				loss, stats, weight = retval
-				
+				stats = {k: v for k, v in stats.items() if v is not None}
+				if self.use_ddp or self.use_fsdp:
+					# Apply weighted averaging for loss and stats
+					loss = (loss * weight.type(loss.dtype)).sum()
+					# if distributed, this method can also apply all_reduce()
+					stats, weight = recursive_average(stats, weight, distributed=True)
+					# Now weight is summation over all workers
+					loss /= weight
+					# Multiply world_size because DistributedDataParallel
+					# automatically normalizes the gradient by world_size.
+					loss *= self.world_size
 				# Scale the loss since we're not updating for every mini-batch
-				loss = loss / accumulation_steps
+				loss = loss / accum_grad
 				loss.backward()
+				time4 = time.perf_counter()
+				speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
 			
 			# Perform an optimizer step only after accumulating enough gradients
-			if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(self.dataloader_train):
+			if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
 				# Perform gradient clipping if it is set
 				if self.kwargs.get("grad_clip", None) is not None:
 					grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -161,49 +196,27 @@
 				self.scheduler.step()
 				# Clear gradients for the next accumulation stage
 				self.optim.zero_grad()
+				total_time = f"{time.perf_counter() - time5:0.3f}"
+				time5 = time.perf_counter()
+				speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
+	
+				speed_stats["total_time"] = total_time
+				
 			
 			pbar.update(1)
-			pbar.set_description(
-				f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float()})")
-		
+			if self.local_rank == 0:
+				description = (
+					f"Epoch: {epoch + 1}/{self.max_epoch}, "
+					f"step {batch_idx}/{len(self.dataloader_train)}, "
+					f"{speed_stats}, "
+					f"(loss: {loss.detach().float():.3f}), "
+					f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+				)
+				pbar.set_description(description)
+			
+			if batch_idx == 2:
+				break
 		pbar.close()
-	
-	# def _train_epoch(self, epoch):
-	# 	"""
-	# 	Defines the training process for a single epoch.
-	# 	Should be implemented with the actual model training steps.
-	#
-	# 	Args:
-	# 		epoch (int): The current epoch number.
-	# 	"""
-	# 	self.model.train()
-	# 	pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train), dynamic_ncols=True)
-	# 	for batch_idx, batch in enumerate(self.dataloader_train):
-	# 		batch = to_device(batch, "cpu")
-	# 		retval = self.model(**batch)
-	# 		loss, stats, weight = retval
-	# 		self.optim.zero_grad()
-	# 		loss.backward()
-	#
-	# 		# compute the gradient norm to check if it is normal or not
-	# 		grad_norm = torch.nn.utils.clip_grad_norm_(
-	# 			self.model.parameters(),
-	# 			max_norm=self.kwargs.get("grad_clip", 10.0),
-	# 			norm_type=self.kwargs.get("grad_clip_type", 2.0),
-	# 		)
-	# 		if not torch.isfinite(grad_norm):
-	# 			logging.warning(
-	# 				f"The grad norm is {grad_norm}. Skipping updating the model."
-	# 			)
-	# 			continue
-	# 		self.optim.step()
-	# 		self.scheduler.step()
-	# 		pbar.update(1)
-	# 		pbar.set_description(
-	# 			f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float()})")
-	#
-	# 	pbar.close()
-	#
 
 	def _validate_epoch(self, epoch):
 		"""
@@ -218,19 +231,3 @@
 			for data, target in self.dataloader_val:
 				# Implement the model validation steps here
 				pass
-
-# # Example usage
-# if __name__ == "__main__":
-# 	# Assuming the following objects have already been correctly created and initialized:
-# 	# model, optim, scheduler, dataloader_train, and dataloader_val.
-# 	trainer = Trainer(
-# 	    max_epoch=10,
-# 	    model=model,
-# 	    optim=optim,
-# 	    scheduler=scheduler,
-# 	    dataloader_train=dataloader_train,
-# 	    dataloader_val=dataloader_val,
-# 	    output_dir='path_to_save_model',
-# 	    resume='path_to_checkpoint_if_any'
-# 	)
-# 	trainer.run()
\ No newline at end of file

--
Gitblit v1.9.1