From c8bae0ec85eee25d66de6b1e4502eff74d750b24 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 21 十二月 2023 13:29:37 +0800
Subject: [PATCH] funasr2

---
 funasr/bin/train.py |    4 ++--
 1 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 8112002..1e06c50 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -145,7 +145,8 @@
 	# dataloader
 	batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
 	batch_sampler_class = registry_tables.batch_sampler_classes.get(batch_sampler.lower())
-	batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+	if batch_sampler is not None:
+		batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
 	dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
 	                                            collate_fn=dataset_tr.collator,
 	                                            batch_sampler=batch_sampler,
@@ -153,7 +154,6 @@
 	                                            pin_memory=True)
 	
 
-	
 	trainer = Trainer(
 	    model=model,
 	    optim=optim,

--
Gitblit v1.9.1