From 8912e0696af069de47646fdb8a9d9c4e086e88b3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 14 一月 2024 23:42:11 +0800
Subject: [PATCH] Resolve merge conflict

---
 funasr/datasets/audio_datasets/samplers.py |   29 ++++++++++++++---------------
 1 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index d34fdea..9c87245 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -4,15 +4,16 @@
 
 from funasr.register import tables
 
+
 @tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
 class BatchSampler(torch.utils.data.BatchSampler):
 	
 	def __init__(self, dataset,
-	             batch_type: str="example",
-	             batch_size: int=100,
-	             buffer_size: int=30,
-	             drop_last: bool=False,
-	             shuffle: bool=True,
+	             batch_type: str = "example",
+	             batch_size: int = 100,
+	             buffer_size: int = 30,
+	             drop_last: bool = False,
+	             shuffle: bool = True,
 	             **kwargs):
 		
 		self.drop_last = drop_last
@@ -25,24 +26,23 @@
 		self.max_token_length = kwargs.get("max_token_length", 5000)
 		self.shuffle_idx = np.arange(self.total_samples)
 		self.shuffle = shuffle
-
 	
 	def __len__(self):
 		return self.total_samples
 	
 	def set_epoch(self, epoch):
 		np.random.seed(epoch)
-		
+	
 	def __iter__(self):
 		
 		if self.shuffle:
 			np.random.shuffle(self.shuffle_idx)
-			
+		
 		batch = []
 		max_token = 0
 		num_sample = 0
-
-		iter_num = (self.total_samples-1) // self.buffer_size + 1
+		
+		iter_num = (self.total_samples - 1) // self.buffer_size + 1
 		# print("iter_num: ", iter_num)
 		for iter in range(self.pre_idx + 1, iter_num):
 			datalen_with_index = []
@@ -50,12 +50,12 @@
 				idx = iter * self.buffer_size + i
 				if idx >= self.total_samples:
 					continue
-
+				
 				idx_map = self.shuffle_idx[idx]
 				# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
 				sample_len_cur = self.dataset.get_source_len(idx_map) + \
 				                 self.dataset.get_target_len(idx_map)
-
+				
 				datalen_with_index.append([idx, sample_len_cur])
 			
 			datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
@@ -63,7 +63,7 @@
 				idx, sample_len_cur_raw = item
 				if sample_len_cur_raw > self.max_token_length:
 					continue
-
+				
 				max_token_cur = max(max_token, sample_len_cur_raw)
 				max_token_padding = 1 + num_sample
 				if self.batch_type == 'length':
@@ -77,5 +77,4 @@
 					batch = [idx]
 					max_token = sample_len_cur_raw
 					num_sample = 1
-					
-		
\ No newline at end of file
+

--
Gitblit v1.9.1