From 15868f623089cf70983a8b4f435ff86e7f160b8a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 06 十二月 2023 23:50:54 +0800
Subject: [PATCH] funasr2

---
 funasr/cli/train_cli.py |   12 +----
 funasr/cli/trainer.py   |   78 +++++++++-----------------------------
 2 files changed, 22 insertions(+), 68 deletions(-)

diff --git a/funasr/cli/train_cli.py b/funasr/cli/train_cli.py
index 687887f..ed62773 100644
--- a/funasr/cli/train_cli.py
+++ b/funasr/cli/train_cli.py
@@ -46,7 +46,7 @@
 	
 	local_rank = int(os.environ.get('LOCAL_RANK', 0))
 	# Check if we are using DDP or FSDP
-	use_ddp = 'WORLD_SIZE' in os.environ
+	use_ddp = 'WORLD_SIZE' in os.environ and os.environ["WORLD_SIZE"] > 1
 	use_fsdp = kwargs.get("use_fsdp", None)
 	if use_ddp or use_fsdp:
 		dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
@@ -109,7 +109,8 @@
 
 	if use_ddp:
 		model = model.cuda(local_rank)
-		model = DDP(model, device_ids=[local_rank])
+		model = DDP(model, device_ids=[local_rank],
+		            find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
 	elif use_fsdp:
 		model = FSDP(model).cuda(local_rank)
 	else:
@@ -157,13 +158,6 @@
 		torch.distributed.destroy_process_group()
 
 	
-	
-def train(epoch, model, op):
-	pass
-
-def val():
-	pass
-
 
 if __name__ == "__main__":
 	main()
\ No newline at end of file
diff --git a/funasr/cli/trainer.py b/funasr/cli/trainer.py
index ee5af0f..686f387 100644
--- a/funasr/cli/trainer.py
+++ b/funasr/cli/trainer.py
@@ -5,6 +5,7 @@
 from tqdm import tqdm
 from contextlib import nullcontext
 import torch.distributed as dist
+from funasr.torch_utils.recursive_op import recursive_average
 
 class Trainer:
 	"""
@@ -56,6 +57,8 @@
 		self.start_epoch = 1
 		self.max_epoch = kwargs.get('max_epoch', 100)
 		self.local_rank = local_rank
+		self.rank = dist.get_rank()
+		self.world_size = dist.get_world_size()
 		self.use_ddp = use_ddp
 		self.use_fsdp = use_fsdp
 		self.device = torch.device("cuda", local_rank)
@@ -113,7 +116,7 @@
 			# self._validate_epoch(epoch)
 			if dist.get_rank() == 0:
 				self._save_checkpoint(epoch)
-			# self.scheduler.step()
+			self.scheduler.step()
 	
 	def _train_epoch(self, epoch):
 		"""
@@ -126,24 +129,34 @@
 		            dynamic_ncols=True)
 		
 		# Set the number of steps for gradient accumulation
-		accumulation_steps = self.kwargs.get("accumulation_steps", 1)
+		accum_grad = self.kwargs.get("accum_grad", 1)
 		# Initialize the gradient accumulation
 		self.optim.zero_grad()
 		
 		for batch_idx, batch in enumerate(self.dataloader_train):
 			batch = to_device(batch, self.device)
 			
-			my_context = self.model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
+			my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
 			with my_context():
 				retval = self.model(**batch)
 				loss, stats, weight = retval
-				
+				stats = {k: v for k, v in stats.items() if v is not None}
+				if self.use_ddp or self.use_fsdp:
+					# Apply weighted averaging for loss and stats
+					loss = (loss * weight.type(loss.dtype)).sum()
+					# if distributed, this method can also apply all_reduce()
+					stats, weight = recursive_average(stats, weight, distributed=True)
+					# Now weight is summation over all workers
+					loss /= weight
+					# Multiply world_size because DistributedDataParallel
+					# automatically normalizes the gradient by world_size.
+					loss *= self.world_size
 				# Scale the loss since we're not updating for every mini-batch
-				loss = loss / accumulation_steps
+				loss = loss / accum_grad
 				loss.backward()
 			
 			# Perform an optimizer step only after accumulating enough gradients
-			if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(self.dataloader_train):
+			if (batch_idx + 1) % accum_grad == 0 or (batch_idx + 1) == len(self.dataloader_train):
 				# Perform gradient clipping if it is set
 				if self.kwargs.get("grad_clip", None) is not None:
 					grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -170,43 +183,6 @@
 					f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float()})")
 			
 		pbar.close()
-	
-	# def _train_epoch(self, epoch):
-	# 	"""
-	# 	Defines the training process for a single epoch.
-	# 	Should be implemented with the actual model training steps.
-	#
-	# 	Args:
-	# 		epoch (int): The current epoch number.
-	# 	"""
-	# 	self.model.train()
-	# 	pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_train), dynamic_ncols=True)
-	# 	for batch_idx, batch in enumerate(self.dataloader_train):
-	# 		batch = to_device(batch, "cpu")
-	# 		retval = self.model(**batch)
-	# 		loss, stats, weight = retval
-	# 		self.optim.zero_grad()
-	# 		loss.backward()
-	#
-	# 		# compute the gradient norm to check if it is normal or not
-	# 		grad_norm = torch.nn.utils.clip_grad_norm_(
-	# 			self.model.parameters(),
-	# 			max_norm=self.kwargs.get("grad_clip", 10.0),
-	# 			norm_type=self.kwargs.get("grad_clip_type", 2.0),
-	# 		)
-	# 		if not torch.isfinite(grad_norm):
-	# 			logging.warning(
-	# 				f"The grad norm is {grad_norm}. Skipping updating the model."
-	# 			)
-	# 			continue
-	# 		self.optim.step()
-	# 		self.scheduler.step()
-	# 		pbar.update(1)
-	# 		pbar.set_description(
-	# 			f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float()})")
-	#
-	# 	pbar.close()
-	#
 
 	def _validate_epoch(self, epoch):
 		"""
@@ -221,19 +197,3 @@
 			for data, target in self.dataloader_val:
 				# Implement the model validation steps here
 				pass
-
-# # Example usage
-# if __name__ == "__main__":
-# 	# Assuming the following objects have already been correctly created and initialized:
-# 	# model, optim, scheduler, dataloader_train, and dataloader_val.
-# 	trainer = Trainer(
-# 	    max_epoch=10,
-# 	    model=model,
-# 	    optim=optim,
-# 	    scheduler=scheduler,
-# 	    dataloader_train=dataloader_train,
-# 	    dataloader_val=dataloader_val,
-# 	    output_dir='path_to_save_model',
-# 	    resume='path_to_checkpoint_if_any'
-# 	)
-# 	trainer.run()
\ No newline at end of file

--
Gitblit v1.9.1