雾聪
2024-01-25 d1efd59af963a25314dbbe254d298ed441695ca1
funasr/datasets/audio_datasets/samplers.py
@@ -13,6 +13,7 @@
                 buffer_size: int = 30,
                 drop_last: bool = False,
                 shuffle: bool = True,
                 is_training: bool = True,
                 **kwargs):
        
        self.drop_last = drop_last
@@ -20,14 +21,14 @@
        self.dataset = dataset
        self.total_samples = len(dataset)
        self.batch_type = batch_type
        self.batch_size = batch_size
        self.batch_size = int(batch_size)
        self.buffer_size = buffer_size
        self.max_token_length = kwargs.get("max_token_length", 5000)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle
        self.shuffle = shuffle and is_training
    
    def __len__(self):
        return self.total_samples
        return (self.total_samples-1) // self.batch_size + 1
    
    def set_epoch(self, epoch):
        np.random.seed(epoch)