From 0cf5dfec2c8313fc2ed2aab8d10bf3dc4b9c283f Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期四, 14 三月 2024 14:41:49 +0800
Subject: [PATCH] update cmakelist
---
funasr/datasets/llm_datasets_vicuna/samplers.py | 9 +++++++++
1 files changed, 9 insertions(+), 0 deletions(-)
diff --git a/funasr/datasets/llm_datasets_vicuna/samplers.py b/funasr/datasets/llm_datasets_vicuna/samplers.py
index 512250d..fe840e2 100644
--- a/funasr/datasets/llm_datasets_vicuna/samplers.py
+++ b/funasr/datasets/llm_datasets_vicuna/samplers.py
@@ -142,6 +142,15 @@
def set_epoch(self, epoch):
self.epoch = epoch
+
+def CustomDistributedBatchSampler_fn(dataset, **kwargs):
+ dataloader_args = {"dataset": dataset}
+ dataloader_args["batch_sampler"] = CustomDistributedBatchSampler(dataset, **kwargs)
+ dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
+ dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)
+
+ return dataloader_args
+
@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler")
class CustomDistributedBatchSampler(Sampler):
def __init__(self, dataset,
--
Gitblit v1.9.1