zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/datasets/audio_datasets/samplers.py
@@ -36,8 +36,11 @@
    
     return dataloader_args
 class CustomDistributedBatchSampler(Sampler):
-    def __init__(self, dataset,
-                 batch_size,
-                 num_replicas=None,
-                 rank=None,
+    def __init__(
+        self,
+        dataset,
+        batch_size,
+        num_replicas=None,
+        rank=None,
@@ -62,9 +65,13 @@
         self.drop_last = drop_last
         # self.total_size = len(dataset)
         if self.drop_last:
-            self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (batch_size * num_replicas)
+            self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (
+                batch_size * num_replicas
+            )
         else:
-            self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (batch_size * num_replicas)
+            self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (
+                batch_size * num_replicas
+            )
         self.num_samples = int(self.total_size // self.num_replicas)
         self.epoch = 0
         self.max_token_length = kwargs.get("max_token_length", None)
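The change above is formatting-only: total_size is still rounded down (drop_last=True) or up (drop_last=False) to a multiple of batch_size * num_replicas, so every rank receives the same number of samples. A minimal standalone sketch of the same arithmetic (function name is illustrative, not from the commit):

    import math

    def compute_total_size(dataset_len, batch_size, num_replicas, drop_last):
        # One "global batch" = batch_size samples on each of num_replicas ranks.
        global_batch = batch_size * num_replicas
        if drop_last:
            return (dataset_len // global_batch) * global_batch  # round down
        return math.ceil(dataset_len / global_batch) * global_batch  # round up

    # 103 samples, batch_size=4, num_replicas=2 (global batch of 8):
    assert compute_total_size(103, 4, 2, drop_last=True) == 96
    assert compute_total_size(103, 4, 2, drop_last=False) == 104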
@@ -84,7 +91,9 @@
         if padding_size <= len(indices):
             indices += indices[:padding_size]
         else:
-            indices += (indices * (padding_size // len(indices)) + indices[:padding_size % len(indices)])
+            indices += (
+                indices * (padding_size // len(indices)) + indices[: padding_size % len(indices)]
+            )
         assert len(indices) == self.total_size
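When the dataset does not divide evenly, the index list is padded up to total_size by repeating itself; the else branch covers the case where the shortfall exceeds the list length. A hedged sketch of the same logic (pad_indices is an illustrative name):

    def pad_indices(indices, total_size):
        # Repeat the index list until it reaches total_size entries.
        padding_size = total_size - len(indices)
        if padding_size <= len(indices):
            return indices + indices[:padding_size]
        return (
            indices
            + indices * (padding_size // len(indices))
            + indices[: padding_size % len(indices)]
        )

    # Shortfall (5) larger than the list (3): the list wraps around.
    assert pad_indices([0, 1, 2], 8) == [0, 1, 2, 0, 1, 2, 0, 1]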
@@ -102,7 +111,9 @@
             indices = filtered_indices
         # Now that we have only the indices for this replica, chunk them into batches
-        batches = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)]
+        batches = [
+            indices[i : i + self.batch_size] for i in range(0, len(indices), self.batch_size)
+        ]
         # Drop the last batch if it's not full and drop_last is True
         if self.drop_last and len(batches[-1]) != self.batch_size:
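The per-replica indices are then chunked by plain slicing on a batch_size stride; for example (a standalone illustration, not code from the commit):

    indices = [3, 7, 1, 4, 9]
    batch_size = 2
    batches = [indices[i : i + batch_size] for i in range(0, len(indices), batch_size)]
    assert batches == [[3, 7], [1, 4], [9]]
    # With drop_last=True the short trailing batch [9] is discarded.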
@@ -117,8 +128,11 @@
     def set_epoch(self, epoch):
         self.epoch = epoch
 class CustomDistributedBufferBatchSampler(Sampler):
-    def __init__(self, dataset,
-                 batch_size,
-                 num_replicas=None,
-                 rank=None,
+    def __init__(
+        self,
+        dataset,
+        batch_size,
+        num_replicas=None,
+        rank=None,
@@ -144,9 +158,13 @@
         self.drop_last = drop_last
         # self.total_size = len(dataset)
         if self.drop_last:
-            self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (batch_size * num_replicas)
+            self.total_size = (len(self.dataset) // (batch_size * num_replicas)) * (
+                batch_size * num_replicas
+            )
         else:
-            self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (batch_size * num_replicas)
+            self.total_size = math.ceil(len(self.dataset) / (batch_size * num_replicas)) * (
+                batch_size * num_replicas
+            )
         self.num_samples = int(self.total_size // self.num_replicas)
         self.epoch = 0
         self.max_token_length = kwargs.get("max_token_length", None)
@@ -167,7 +185,9 @@
         if padding_size <= len(indices):
             indices += indices[:padding_size]
         else:
-            indices += (indices * (padding_size // len(indices)) + indices[:padding_size % len(indices)])
+            indices += (
+                indices * (padding_size // len(indices)) + indices[: padding_size % len(indices)]
+            )
 
         assert len(indices) == self.total_size
 
@@ -205,7 +225,9 @@
     def _create_batches_from_buffer(self, buffer):
         # Function to convert the sorted buffer into batches
-        batched_buffer = [buffer[i:i + self.batch_size] for i in range(0, len(buffer), self.batch_size)]
+        batched_buffer = [
+            buffer[i : i + self.batch_size] for i in range(0, len(buffer), self.batch_size)
+        ]
         if self.drop_last and len(batched_buffer[-1]) != self.batch_size:
             batched_buffer = batched_buffer[:-1]
         return batched_buffer
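The buffer variant sorts a window of indices by sample length before chunking, so each batch holds similar-length utterances and padding is reduced. A minimal sketch of the idea (get_len stands in for the dataset's length lookup and is an assumption here):

    def batches_from_buffer(buffer, get_len, batch_size, drop_last):
        # Length-sort the window, then chunk: similar-length samples
        # land in the same batch, which cuts padding waste.
        buffer = sorted(buffer, key=get_len)
        batches = [buffer[i : i + batch_size] for i in range(0, len(buffer), batch_size)]
        if drop_last and batches and len(batches[-1]) != batch_size:
            batches = batches[:-1]
        return batches

    lengths = {0: 50, 1: 10, 2: 30, 3: 20}
    assert batches_from_buffer([0, 1, 2, 3], lengths.get, 2, False) == [[1, 3], [2, 0]]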
@@ -217,8 +239,11 @@
     def set_epoch(self, epoch):
         self.epoch = epoch
 class CustomDistributedDynamicBatchSampler(DistributedSampler):
-    def __init__(self, dataset,
-                 batch_size,
-                 num_replicas=None,
-                 rank=None,
+    def __init__(
+        self,
+        dataset,
+        batch_size,
+        num_replicas=None,
+        rank=None,
@@ -267,8 +292,9 @@
             sample_length = self.dataset.get_source_len(idx)
             if sample_length > self.max_token_length:
                 continue
-            potential_batch_length = (max_len_in_batch if sample_length < max_len_in_batch else sample_length) * (
-                    len(batch) + 1)
+            potential_batch_length = (
+                max_len_in_batch if sample_length < max_len_in_batch else sample_length
+            ) * (len(batch) + 1)
 
             if potential_batch_length <= self.batch_size:
                 batch.append(idx)
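In the dynamic sampler, batch_size acts as a token budget rather than a sample count: the padded cost of a batch is its longest sample length times its size, and a sample is admitted only while that cost stays within budget. A hedged standalone sketch of the loop (names are illustrative):

    def dynamic_batches(lengths, token_budget, max_token_length):
        batches, batch, max_len_in_batch = [], [], 0
        for idx, sample_length in enumerate(lengths):
            if sample_length > max_token_length:
                continue  # over-long samples are skipped outright
            # Padded cost if this sample joins: new max length times new batch size.
            potential = max(max_len_in_batch, sample_length) * (len(batch) + 1)
            if potential <= token_budget:
                batch.append(idx)
                max_len_in_batch = max(max_len_in_batch, sample_length)
            else:
                batches.append(batch)
                batch, max_len_in_batch = [idx], sample_length
        if batch:
            batches.append(batch)
        return batches

    # Budget 100: [40, 30] costs 40*2=80; adding 50 would cost 50*3=150.
    assert dynamic_batches([40, 30, 50], 100, 2048) == [[0, 1], [2]]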
@@ -296,7 +322,9 @@
 class CustomDistributedBufferDynamicBatchSampler(DistributedSampler):
-    def __init__(self, dataset,
-                 batch_size,
-                 batch_type="token",
-                 num_replicas=None,
+    def __init__(
+        self,
+        dataset,
+        batch_size,
+        batch_type="token",
+        num_replicas=None,
@@ -336,8 +364,9 @@
         self.sort_size = sort_size * num_replicas
         self.max_token_length = kwargs.get("max_token_length", 2048)
         self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-        super().__init__(dataset, num_replicas=num_replicas, rank=rank,
-                         shuffle=shuffle, drop_last=drop_last)
+        super().__init__(
+            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, drop_last=drop_last
+        )
 
     def __iter__(self):
         if self.shuffle:
@@ -352,7 +381,9 @@
         # Create sorted buffers and form batches
         buffer_batches = []
         for i in range(0, len(indices), self.sort_size):
-            buffer = sorted(indices[i:i + self.sort_size], key=lambda idx: self.dataset.get_source_len(idx))
+            buffer = sorted(
+                indices[i : i + self.sort_size], key=lambda idx: self.dataset.get_source_len(idx)
+            )
             batch = []
             max_len_in_batch = 0
             for idx in buffer:
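Here sort_size trades randomness against padding efficiency: indices are shuffled globally for the epoch, then length-sorted only within windows of sort_size * num_replicas. A small illustration with made-up lengths:

    import random

    lengths = [9, 2, 7, 4, 8, 1, 5, 3]   # per-index source lengths (illustrative)
    indices = list(range(len(lengths)))
    random.Random(0).shuffle(indices)     # epoch-level shuffle
    sort_size = 4
    windows = [
        sorted(indices[i : i + sort_size], key=lambda idx: lengths[idx])
        for i in range(0, len(indices), sort_size)
    ]
    # Each window is locally length-sorted, but window order stays shuffled,
    # so batches are padding-friendly without a fully deterministic epoch.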
@@ -388,7 +419,6 @@
         return iter(final_batches)
 
     def __len__(self):
-
         return 1
@@ -398,7 +428,9 @@
 class DistributedSamplerWarp(BatchSampler):
-    def __init__(self, dataset, batch_size, num_replicas=None, rank=None, shuffle=True, drop_last=False):
+    def __init__(
+        self, dataset, batch_size, num_replicas=None, rank=None, shuffle=True, drop_last=False
+    ):
         if num_replicas is None:
             if not torch.distributed.is_available():
                 raise RuntimeError("Requires distributed package to be available")
@@ -417,10 +449,7 @@
        
         # Create an instance of the DistributedSampler
         self.sampler = DistributedSampler(
-            self.dataset,
-            num_replicas=self.num_replicas,
-            rank=self.rank,
-            shuffle=self.shuffle
+            self.dataset, num_replicas=self.num_replicas, rank=self.rank, shuffle=self.shuffle
         )
 
         # Call BatchSampler's constructor
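For context, these samplers are meant to be handed to a PyTorch DataLoader through its batch_sampler argument; a hedged usage sketch (my_dataset, my_collate_fn, num_epochs, and the keyword values are placeholders, not from this commit):

    import torch
    from torch.utils.data import DataLoader

    sampler = CustomDistributedBufferDynamicBatchSampler(
        my_dataset,                 # assumed to expose get_source_len(idx)
        batch_size=6000,            # token budget when batch_type="token"
        num_replicas=torch.distributed.get_world_size(),
        rank=torch.distributed.get_rank(),
        shuffle=True,
    )
    loader = DataLoader(my_dataset, batch_sampler=sampler, collate_fn=my_collate_fn)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)    # reseed the shuffle each epoch
        for batch in loader:
            ...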