From c8bae0ec85eee25d66de6b1e4502eff74d750b24 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 21 十二月 2023 13:29:37 +0800
Subject: [PATCH] funasr2
---
funasr/bin/train.py | 4 ++--
1 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 8112002..1e06c50 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -145,7 +145,8 @@
# dataloader
batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
batch_sampler_class = registry_tables.batch_sampler_classes.get(batch_sampler.lower())
- batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+ if batch_sampler is not None:
+ batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
collate_fn=dataset_tr.collator,
batch_sampler=batch_sampler,
@@ -153,7 +154,6 @@
pin_memory=True)
-
trainer = Trainer(
model=model,
optim=optim,
--
Gitblit v1.9.1