From 32e783664534bbb8d3b8ba64c2c2ecb42398eb00 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 06 六月 2024 09:54:35 +0800
Subject: [PATCH] update with main (#1786)

---
 funasr/datasets/audio_datasets/espnet_samplers.py |   40 ++++++++++++++++++++++++++++++----------
 1 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index e6efe0a..004201e 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -41,6 +41,7 @@
         drop_last=False,
         is_training: bool = True,
         sort_size: int = 1024,
+        start_step: int = 0,
         **kwargs,
     ):
 
@@ -70,7 +71,10 @@
         self.max_token_length = kwargs.get("max_token_length", 2048)
         self.min_token_length = kwargs.get("min_token_length", 0)
         self.length_scale_source = kwargs.get("length_scale_source", 1.0)
-
+        self.start_step = start_step
+        self.batch_num = 1
+        if self.start_step > 0:
+            logging.info(f"Warning, start_step > 0, dataloader start from step: {self.start_step}")
         # super().__init__(dataset, num_replicas=num_replicas, rank=rank,
         #                  shuffle=shuffle, drop_last=drop_last)
 
@@ -92,14 +96,25 @@
         max_len_in_batch = 0  # Tracks the max sample length within the current batch
 
         for idx in sorted_indices:
-            original_sample_length = self.dataset.get_source_len(idx)
-            if (
-                original_sample_length < self.min_token_length
-                or original_sample_length > self.max_token_length
-            ):  # Skip samples that exceed the max length
-                continue
+
+            # original_sample_length = self.dataset.get_source_len(idx)
+            # if (
+            #     original_sample_length < self.min_token_length
+            #     or original_sample_length > self.max_token_length
+            # ):  # Skip samples that exceed the max length
+            #     continue
+
+            # sample_length = 1 if self.batch_type == "example" else original_sample_length
+
             # Set sample_length based on the batch type
-            sample_length = 1 if self.batch_type == "example" else original_sample_length
+            if self.batch_type == "example":
+                sample_length = 1
+            elif self.batch_type == "token":
+                sample_length = self.dataset.get_source_len(idx) + int(
+                    self.dataset.get_target_len(idx) * 1.2
+                )
+            else:
+                sample_length = self.dataset.get_source_len(idx)
             # Calculate potential batch size with the new sample
             potential_batch_length = max(max_len_in_batch, sample_length) * (len(batch) + 1)
             # Add index to batch if it doesn't exceed batch size limit
@@ -131,14 +146,19 @@
         # Allocate the batches to the current rank
         start_idx = self.rank * batches_per_rank
         end_idx = start_idx + batches_per_rank
-        rank_batches = buffer_batches[start_idx:end_idx]
+        rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
 
+        self.batch_num = len(rank_batches)
+
+        logging.info(
+            f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {end_idx-start_idx}, batch_num_after_step: {len(rank_batches)}"
+        )
         # Return an iterator over the batches for the current rank
         return iter(rank_batches)
 
     def __len__(self):
         # Calculate the number of batches per epoch for the current rank
-        return 1
+        return self.batch_num
 
     def set_epoch(self, epoch):
         # Set the epoch for shuffling

--
Gitblit v1.9.1