From e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc Mon Sep 17 00:00:00 2001
From: VirtuosoQ <2416050435@qq.com>
Date: Fri, 26 Apr 2024 14:59:30 +0800
Subject: [PATCH] FunASR java http client

---
 funasr/bin/compute_audio_cmvn.py |   26 ++++++++++----------------
 1 file changed, 10 insertions(+), 16 deletions(-)

diff --git a/funasr/bin/compute_audio_cmvn.py b/funasr/bin/compute_audio_cmvn.py
index 4561bec..cd64329 100644
--- a/funasr/bin/compute_audio_cmvn.py
+++ b/funasr/bin/compute_audio_cmvn.py
@@ -18,7 +18,7 @@
 
     assert "model" in kwargs
     if "model_conf" not in kwargs:
-        logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+        logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
         kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
     
 
@@ -28,7 +28,7 @@
 def main(**kwargs):
     print(kwargs)
     # set random seed
-    tables.print()
+    # tables.print()
     set_all_random_seed(kwargs.get("seed", 0))
     torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
     torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
@@ -54,21 +54,15 @@
     dataset_train = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=None, is_training=False, **kwargs.get("dataset_conf"))
 
     # dataloader
-    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
-    batch_sampler_train = None
-    if batch_sampler is not None:
-        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
-        dataset_conf = kwargs.get("dataset_conf")
-        dataset_conf["batch_type"] = "example"
-        dataset_conf["batch_size"] = 1
-        batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)
+    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
+    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
+    dataset_conf = kwargs.get("dataset_conf")
+    dataset_conf["batch_type"] = "example"
+    dataset_conf["batch_size"] = 1
+    dataset_conf["num_workers"] = os.cpu_count() or 32
+    batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)
 
-    
-    dataloader_train = torch.utils.data.DataLoader(dataset_train,
-                                                collate_fn=dataset_train.collator,
-                                                batch_sampler=batch_sampler_train,
-                                                num_workers=int(kwargs.get("dataset_conf").get("num_workers", 4)),
-                                                pin_memory=True)
+    dataloader_train = torch.utils.data.DataLoader(dataset_train, collate_fn=dataset_train.collator, **batch_sampler_train)
     
     iter_stop = int(kwargs.get("scale", 1.0)*len(dataloader_train))
 

--
Gitblit v1.9.1