游雁
2023-11-27 54a91194901ad72562d5cb5856ee8c302d93fb0e
funasr/datasets/data_sampler.py
@@ -4,7 +4,7 @@
class BatchSampler(torch.utils.data.BatchSampler):
   
   def __init__(self, dataset, batch_size_type: str="example", batch_size: int=14, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
   def __init__(self, dataset, batch_size_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
      
      self.drop_last = drop_last
      self.pre_idx = -1
@@ -46,8 +46,8 @@
            idx_map = self.shuffle_idx[idx]
            # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
            sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
                             self.dataset.indexed_dataset[idx_map]["target_len"]
            sample_len_cur = self.dataset.indexed_dataset.get_source_len(self.dataset.indexed_dataset[idx_map]) + \
                             self.dataset.indexed_dataset.get_target_len(self.dataset.indexed_dataset[idx_map])
            datalen_with_index.append([idx, sample_len_cur])