From fcbbe8af9f22a41611d9506af17cae1e422f9fec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 22 三月 2024 19:25:29 +0800
Subject: [PATCH] update
---
funasr/bin/compute_audio_cmvn.py | 23 ++++++++---------------
1 files changed, 8 insertions(+), 15 deletions(-)
diff --git a/funasr/bin/compute_audio_cmvn.py b/funasr/bin/compute_audio_cmvn.py
index ffad652..6282e70 100644
--- a/funasr/bin/compute_audio_cmvn.py
+++ b/funasr/bin/compute_audio_cmvn.py
@@ -28,7 +28,7 @@
def main(**kwargs):
print(kwargs)
# set random seed
- tables.print()
+ # tables.print()
set_all_random_seed(kwargs.get("seed", 0))
torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
@@ -54,21 +54,14 @@
dataset_train = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=None, is_training=False, **kwargs.get("dataset_conf"))
# dataloader
- batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
- batch_sampler_train = None
- if batch_sampler is not None:
- batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
- dataset_conf = kwargs.get("dataset_conf")
- dataset_conf["batch_type"] = "example"
- dataset_conf["batch_size"] = 1
- batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)
+ batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
+ batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
+ dataset_conf = kwargs.get("dataset_conf")
+ dataset_conf["batch_type"] = "example"
+ dataset_conf["batch_size"] = 1
+ batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)
-
- dataloader_train = torch.utils.data.DataLoader(dataset_train,
- collate_fn=dataset_train.collator,
- batch_sampler=batch_sampler_train,
- num_workers=int(kwargs.get("dataset_conf").get("num_workers", 4)),
- pin_memory=True)
+ dataloader_train = torch.utils.data.DataLoader(dataset_train, collate_fn=dataset_train.collator, **batch_sampler_train)
iter_stop = int(kwargs.get("scale", 1.0)*len(dataloader_train))
--
Gitblit v1.9.1