| | |
| | | # self.kwargs = kwargs |
| | | self.max_token_length = kwargs.get("max_token_length", 1024) |
| | | self.batch_size_scale_ratio_max = kwargs.get("batch_size_scale_ratio_max", 1.5) |
| | | self.batch_size_token_max = kwargs.get("batch_size_token_max", 2500) |
| | | |
| | | def get_source_len(self, index): |
| | | item = self.index_ds[index] |
| | |
| | | b, t = outputs["input_ids"].shape |
| | | if b > 1 and b * t > self.batch_size * self.batch_size_scale_ratio_max: |
| | | logging.info( |
| | | f"Warning, b*t: {b}*{t}={b * t} > batch_size*relax: {self.batch_size_scale_ratio_max}*{self.batch_size}={self.batch_size_scale_ratio_max*self.batch_size}, drop last data" |
| | | f"Warning, {idx}th, b*t: {b}*{t}={b * t} > batch_size_sample_max: {self.batch_size_token_max}, drop last data" |
| | | ) |
| | | samples = samples[:-1] |
| | | continue |