zhifu gao
2024-04-25 80bd14e6bbb7bb282ff3832194648dc4a16157ca
funasr/datasets/audio_datasets/espnet_samplers.py
@@ -92,14 +92,25 @@
        max_len_in_batch = 0  # Tracks the max sample length within the current batch
        for idx in sorted_indices:
            original_sample_length = self.dataset.get_source_len(idx)
            if (
                original_sample_length < self.min_token_length
                or original_sample_length > self.max_token_length
            ):  # Skip samples that exceed the max length
                continue
            # original_sample_length = self.dataset.get_source_len(idx)
            # if (
            #     original_sample_length < self.min_token_length
            #     or original_sample_length > self.max_token_length
            # ):  # Skip samples that exceed the max length
            #     continue
            # sample_length = 1 if self.batch_type == "example" else original_sample_length
            # Set sample_length based on the batch type
            sample_length = 1 if self.batch_type == "example" else original_sample_length
            if self.batch_type == "example":
                sample_length = 1
            elif self.batch_type == "token":
                sample_length = self.dataset.get_source_len(idx) + int(
                    self.dataset.get_target_len(idx) * 1.2
                )
            else:
                sample_length = self.dataset.get_source_len(idx)
            # Calculate potential batch size with the new sample
            potential_batch_length = max(max_len_in_batch, sample_length) * (len(batch) + 1)
            # Add index to batch if it doesn't exceed batch size limit