From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
---
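Note for reviewers (not part of the commit): a minimal sketch of the token-based batch assembly this patch introduces in __iter__(), using the get_source_len()/get_target_len() accessors that appear in the diff. The helper name assemble_batches and the stub dataset are hypothetical, added only to make the sketch self-contained; the real sampler also honors drop_last, shuffling, and per-rank allocation.

    def assemble_batches(dataset, indices, batch_size, batch_type="token"):
        # Sort by source length so similarly sized samples share a batch.
        buffer_batches, batch, max_len_in_batch = [], [], 0
        for idx in sorted(indices, key=dataset.get_source_len):
            if batch_type == "example":
                sample_length = 1
            elif batch_type == "token":
                # Source length plus a 1.2x-weighted target length, as in the patch.
                sample_length = dataset.get_source_len(idx) + int(
                    dataset.get_target_len(idx) * 1.2
                )
            else:
                sample_length = dataset.get_source_len(idx)
            # Padded cost of the batch if this sample were added.
            potential = max(max_len_in_batch, sample_length) * (len(batch) + 1)
            if potential <= batch_size:
                batch.append(idx)
                max_len_in_batch = max(max_len_in_batch, sample_length)
            else:
                buffer_batches.append(batch)
                batch, max_len_in_batch = [idx], sample_length
        if batch:
            buffer_batches.append(batch)
        return buffer_batches

    # Hypothetical usage with a stub dataset exposing the two accessors:
    class _StubDataset:
        def __init__(self, src, tgt):
            self.src, self.tgt = src, tgt
        def get_source_len(self, i):
            return self.src[i]
        def get_target_len(self, i):
            return self.tgt[i]

    ds = _StubDataset(src=[120, 80, 200, 60], tgt=[20, 10, 40, 5])
    print(assemble_batches(ds, range(4), batch_size=400))  # [[3, 1], [0], [2]]

After shuffling, the patch also slices each rank's batch list as buffer_batches[start_idx + self.start_step : end_idx] so training can resume mid-epoch, and __len__() now returns the post-slice batch count instead of the previous hard-coded 1.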
funasr/datasets/audio_datasets/espnet_samplers.py | 99 ++++++++++++++++++++++++++++++++-----------------
 1 file changed, 64 insertions(+), 35 deletions(-)
diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index 1524a6a..004201e 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -18,7 +18,7 @@
dataloader_args["batch_sampler"] = batch_sampler
dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
-
+
return dataloader_args
@@ -29,17 +29,21 @@
class EspnetStyleBatchSampler(DistributedSampler):
- def __init__(self, dataset,
- batch_size,
- batch_type="token",
- num_replicas=None,
- rank=None,
- shuffle=True,
- drop_last=False,
- is_training: bool = True,
- sort_size: int = 1024,
- **kwargs,
- ):
+ def __init__(
+ self,
+ dataset,
+ batch_size,
+ batch_type="token",
+ rank=None,
+ num_replicas=None,
+ rank_split=False,
+ shuffle=True,
+ drop_last=False,
+ is_training: bool = True,
+ sort_size: int = 1024,
+ start_step: int = 0,
+ **kwargs,
+ ):
try:
rank = dist.get_rank()
@@ -47,6 +51,10 @@
except:
rank = 0
num_replicas = 1
+ # if rank_split:
+ # logging.info(f"Warning, rank_split: {rank_split}, batch and shuffle data in local rank")
+ # rank = 0
+ # num_replicas = 1
self.rank = rank
self.num_replicas = num_replicas
self.dataset = dataset
@@ -56,16 +64,20 @@
self.shuffle = shuffle and is_training
self.drop_last = drop_last
- # self.total_size = len(self.dataset)
- # self.num_samples = int(math.ceil(self.total_size / self.num_replicas))
+ self.total_size = len(self.dataset)
+ self.num_samples = int(math.ceil(self.total_size / self.num_replicas))
self.epoch = 0
self.sort_size = sort_size * num_replicas
self.max_token_length = kwargs.get("max_token_length", 2048)
+ self.min_token_length = kwargs.get("min_token_length", 0)
self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+ self.start_step = start_step
+ self.batch_num = 1
+ if self.start_step > 0:
+ logging.info(f"Warning, start_step > 0, dataloader start from step: {self.start_step}")
+ # super().__init__(dataset, num_replicas=num_replicas, rank=rank,
+ # shuffle=shuffle, drop_last=drop_last)
-
- super().__init__(dataset, num_replicas=num_replicas, rank=rank,
- shuffle=shuffle, drop_last=drop_last)
def __iter__(self):
if self.shuffle:
g = torch.Generator()
@@ -74,21 +86,35 @@
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
-
+
# Sort indices by sample length
sorted_indices = sorted(indices, key=lambda idx: self.dataset.get_source_len(idx))
-
+
# Organize batches based on 'length' or 'example'
buffer_batches = []
batch = []
max_len_in_batch = 0 # Tracks the max sample length within the current batch
-
+
for idx in sorted_indices:
- original_sample_length = self.dataset.get_source_len(idx)
- if original_sample_length > self.max_token_length: # Skip samples that exceed the max length
- continue
+
+ # original_sample_length = self.dataset.get_source_len(idx)
+ # if (
+ # original_sample_length < self.min_token_length
+ # or original_sample_length > self.max_token_length
+ # ): # Skip samples that exceed the max length
+ # continue
+
+ # sample_length = 1 if self.batch_type == "example" else original_sample_length
+
# Set sample_length based on the batch type
- sample_length = 1 if self.batch_type == "example" else original_sample_length
+ if self.batch_type == "example":
+ sample_length = 1
+ elif self.batch_type == "token":
+ sample_length = self.dataset.get_source_len(idx) + int(
+ self.dataset.get_target_len(idx) * 1.2
+ )
+ else:
+ sample_length = self.dataset.get_source_len(idx)
# Calculate potential batch size with the new sample
potential_batch_length = max(max_len_in_batch, sample_length) * (len(batch) + 1)
# Add index to batch if it doesn't exceed batch size limit
@@ -100,37 +126,40 @@
buffer_batches.append(batch)
batch = [idx]
max_len_in_batch = sample_length
-
+
# Add the last batch if it shouldn't be dropped
if batch and (not self.drop_last or len(batch) * max_len_in_batch == self.batch_size):
buffer_batches.append(batch)
-
+
# Shuffle the list of batches
if self.shuffle:
random.seed(self.epoch)
random.shuffle(buffer_batches)
-
+
# Ensure each rank gets the same number of batches
batches_per_rank = int(math.ceil(len(buffer_batches) / self.num_replicas))
total_batches_needed = batches_per_rank * self.num_replicas
extra_batches = total_batches_needed - len(buffer_batches)
# Add extra batches by random selection, if needed
buffer_batches += random.choices(buffer_batches, k=extra_batches)
-
+
# Allocate the batches to the current rank
start_idx = self.rank * batches_per_rank
end_idx = start_idx + batches_per_rank
- rank_batches = buffer_batches[start_idx:end_idx]
-
+ rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
+
+ self.batch_num = len(rank_batches)
+
+ logging.info(
+ f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {end_idx-start_idx}, batch_num_after_step: {len(rank_batches)}"
+ )
# Return an iterator over the batches for the current rank
return iter(rank_batches)
-
+
def __len__(self):
# Calculate the number of batches per epoch for the current rank
- return 1
-
+ return self.batch_num
+
def set_epoch(self, epoch):
# Set the epoch for shuffling
self.epoch = epoch
-
-
--
Gitblit v1.9.1