From fcbbe8af9f22a41611d9506af17cae1e422f9fec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 22 Mar 2024 19:25:29 +0800
Subject: [PATCH] Switch default batch_sampler to BatchSampler in configs and data loading

---
 examples/aishell/conformer/conf/conformer_12e_6d_2048_256.yaml             |    2 
 examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml   |    2 
 examples/aishell/transformer/conf/transformer_12e_6d_2048_256.yaml         |    2 
 funasr/models/monotonic_aligner/template.yaml                              |    2 
 funasr/models/bicif_paraformer/template.yaml                               |    2 
 examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml       |    2 
 funasr/models/branchformer/template.yaml                                   |    2 
 examples/industrial_data_pretraining/llm_asr/conf/template.yaml            |    2 
 funasr/models/conformer/template.yaml                                      |    2 
 funasr/models/transformer/template.yaml                                    |    2 
 funasr/bin/compute_audio_cmvn.py                                           |   23 ++++-------
 funasr/models/contextual_paraformer/template.yaml                          |    2 
 funasr/models/e_branchformer/template.yaml                                 |    2 
 funasr/datasets/dataloader_entry.py                                        |    2 
 funasr/datasets/audio_datasets/scp2jsonl.py                                |    1 
 examples/aishell/paraformer/conf/paraformer_conformer_12e_6d_2048_256.yaml |    2 
 funasr/models/scama/template.yaml                                          |    2 
 funasr/models/paraformer/template.yaml                                     |    2 
 funasr/datasets/audio_datasets/jsonl2scp.py                                |    1 
 funasr/bin/train.py                                                        |    2 
 funasr/models/seaco_paraformer/template.yaml                               |    2 
 funasr/models/sanm/template.yaml                                           |    2 
 funasr/models/paraformer_streaming/template.yaml                           |    2 
 funasr/models/uniasr/template.yaml                                         |    2 
 24 files changed, 31 insertions(+), 36 deletions(-)

diff --git a/examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml b/examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml
index d86e628..61f815e 100644
--- a/examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml
+++ b/examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml
@@ -94,7 +94,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: RankFullLocalShuffleBatchSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 32 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/examples/aishell/conformer/conf/conformer_12e_6d_2048_256.yaml b/examples/aishell/conformer/conf/conformer_12e_6d_2048_256.yaml
index 4238d84..12eeb3e 100644
--- a/examples/aishell/conformer/conf/conformer_12e_6d_2048_256.yaml
+++ b/examples/aishell/conformer/conf/conformer_12e_6d_2048_256.yaml
@@ -94,7 +94,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: RankFullLocalShuffleBatchSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 32 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml b/examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml
index 2773459..df13e67 100644
--- a/examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml
+++ b/examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml
@@ -94,7 +94,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: RankFullLocalShuffleBatchSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 32 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/examples/aishell/paraformer/conf/paraformer_conformer_12e_6d_2048_256.yaml b/examples/aishell/paraformer/conf/paraformer_conformer_12e_6d_2048_256.yaml
index c100f0d..5625623 100644
--- a/examples/aishell/paraformer/conf/paraformer_conformer_12e_6d_2048_256.yaml
+++ b/examples/aishell/paraformer/conf/paraformer_conformer_12e_6d_2048_256.yaml
@@ -93,7 +93,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: RankFullLocalShuffleBatchSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 32 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/examples/aishell/transformer/conf/transformer_12e_6d_2048_256.yaml b/examples/aishell/transformer/conf/transformer_12e_6d_2048_256.yaml
index 674fc94..51936c1 100644
--- a/examples/aishell/transformer/conf/transformer_12e_6d_2048_256.yaml
+++ b/examples/aishell/transformer/conf/transformer_12e_6d_2048_256.yaml
@@ -88,7 +88,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: RankFullLocalShuffleBatchSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 32 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/examples/industrial_data_pretraining/llm_asr/conf/template.yaml b/examples/industrial_data_pretraining/llm_asr/conf/template.yaml
index 3c51ff4..c64c886 100644
--- a/examples/industrial_data_pretraining/llm_asr/conf/template.yaml
+++ b/examples/industrial_data_pretraining/llm_asr/conf/template.yaml
@@ -73,7 +73,7 @@
 dataset: AudioLLMDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: RankFullLocalShuffleBatchSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 8 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/funasr/bin/compute_audio_cmvn.py b/funasr/bin/compute_audio_cmvn.py
index ffad652..6282e70 100644
--- a/funasr/bin/compute_audio_cmvn.py
+++ b/funasr/bin/compute_audio_cmvn.py
@@ -28,7 +28,7 @@
 def main(**kwargs):
     print(kwargs)
     # set random seed
-    tables.print()
+    # tables.print()
     set_all_random_seed(kwargs.get("seed", 0))
     torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
     torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
@@ -54,21 +54,14 @@
     dataset_train = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=None, is_training=False, **kwargs.get("dataset_conf"))
 
     # dataloader
-    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
-    batch_sampler_train = None
-    if batch_sampler is not None:
-        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
-        dataset_conf = kwargs.get("dataset_conf")
-        dataset_conf["batch_type"] = "example"
-        dataset_conf["batch_size"] = 1
-        batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)
+    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
+    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
+    dataset_conf = kwargs.get("dataset_conf")
+    dataset_conf["batch_type"] = "example"
+    dataset_conf["batch_size"] = 1
+    batch_sampler_train = batch_sampler_class(dataset_train, is_training=False, **dataset_conf)
 
-    
-    dataloader_train = torch.utils.data.DataLoader(dataset_train,
-                                                collate_fn=dataset_train.collator,
-                                                batch_sampler=batch_sampler_train,
-                                                num_workers=int(kwargs.get("dataset_conf").get("num_workers", 4)),
-                                                pin_memory=True)
+    dataloader_train = torch.utils.data.DataLoader(dataset_train, collate_fn=dataset_train.collator, **batch_sampler_train)
     
     iter_stop = int(kwargs.get("scale", 1.0)*len(dataloader_train))
 
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 3f97f9e..1db37dd 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -148,7 +148,7 @@
 
     # dataset
     logging.info("Build dataloader")
-    dataloader_class = tables.dataloader_classes.get( kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
+    dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
     dataloader_tr, dataloader_val = dataloader_class(**kwargs)
 
     trainer = Trainer(local_rank=local_rank,
diff --git a/funasr/datasets/audio_datasets/jsonl2scp.py b/funasr/datasets/audio_datasets/jsonl2scp.py
index 9a2b023..b2d1703 100644
--- a/funasr/datasets/audio_datasets/jsonl2scp.py
+++ b/funasr/datasets/audio_datasets/jsonl2scp.py
@@ -40,6 +40,7 @@
 def main_hydra(cfg: DictConfig):
  
     kwargs = OmegaConf.to_container(cfg, resolve=True)
+    print(kwargs)
 
     scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"))
     if isinstance(scp_file_list, str):
diff --git a/funasr/datasets/audio_datasets/scp2jsonl.py b/funasr/datasets/audio_datasets/scp2jsonl.py
index 00754dd..a537127 100644
--- a/funasr/datasets/audio_datasets/scp2jsonl.py
+++ b/funasr/datasets/audio_datasets/scp2jsonl.py
@@ -79,6 +79,7 @@
 def main_hydra(cfg: DictConfig):
  
     kwargs = OmegaConf.to_container(cfg, resolve=True)
+    print(kwargs)
 
     scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"))
     if isinstance(scp_file_list, str):
diff --git a/funasr/datasets/dataloader_entry.py b/funasr/datasets/dataloader_entry.py
index 21e3834..a1e4da2 100644
--- a/funasr/datasets/dataloader_entry.py
+++ b/funasr/datasets/dataloader_entry.py
@@ -13,7 +13,7 @@
 	dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
 	
 	# dataloader
-	batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
+	batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "BatchSampler")
 	batch_sampler_val = None
 	if batch_sampler is not None:
 		batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
diff --git a/funasr/models/bicif_paraformer/template.yaml b/funasr/models/bicif_paraformer/template.yaml
index d2b0e0a..db7ce55 100644
--- a/funasr/models/bicif_paraformer/template.yaml
+++ b/funasr/models/bicif_paraformer/template.yaml
@@ -112,7 +112,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/funasr/models/branchformer/template.yaml b/funasr/models/branchformer/template.yaml
index 09b1e0f..0fec497 100644
--- a/funasr/models/branchformer/template.yaml
+++ b/funasr/models/branchformer/template.yaml
@@ -94,7 +94,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/funasr/models/conformer/template.yaml b/funasr/models/conformer/template.yaml
index f646acc..5a8687e 100644
--- a/funasr/models/conformer/template.yaml
+++ b/funasr/models/conformer/template.yaml
@@ -95,7 +95,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/funasr/models/contextual_paraformer/template.yaml b/funasr/models/contextual_paraformer/template.yaml
index fbd15ce..2205250 100644
--- a/funasr/models/contextual_paraformer/template.yaml
+++ b/funasr/models/contextual_paraformer/template.yaml
@@ -108,7 +108,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/funasr/models/e_branchformer/template.yaml b/funasr/models/e_branchformer/template.yaml
index dea580e..7b0b27c 100644
--- a/funasr/models/e_branchformer/template.yaml
+++ b/funasr/models/e_branchformer/template.yaml
@@ -94,7 +94,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/funasr/models/monotonic_aligner/template.yaml b/funasr/models/monotonic_aligner/template.yaml
index b1379de..f8d5ded 100644
--- a/funasr/models/monotonic_aligner/template.yaml
+++ b/funasr/models/monotonic_aligner/template.yaml
@@ -93,7 +93,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/funasr/models/paraformer/template.yaml b/funasr/models/paraformer/template.yaml
index bccf638..7809457 100644
--- a/funasr/models/paraformer/template.yaml
+++ b/funasr/models/paraformer/template.yaml
@@ -100,7 +100,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/funasr/models/paraformer_streaming/template.yaml b/funasr/models/paraformer_streaming/template.yaml
index d1300ac..889971a 100644
--- a/funasr/models/paraformer_streaming/template.yaml
+++ b/funasr/models/paraformer_streaming/template.yaml
@@ -121,7 +121,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/funasr/models/sanm/template.yaml b/funasr/models/sanm/template.yaml
index 156926f..a7f7b12 100644
--- a/funasr/models/sanm/template.yaml
+++ b/funasr/models/sanm/template.yaml
@@ -98,7 +98,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/funasr/models/scama/template.yaml b/funasr/models/scama/template.yaml
index f647a92..214046e 100644
--- a/funasr/models/scama/template.yaml
+++ b/funasr/models/scama/template.yaml
@@ -104,7 +104,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/funasr/models/seaco_paraformer/template.yaml b/funasr/models/seaco_paraformer/template.yaml
index ab2301a..fcaf524 100644
--- a/funasr/models/seaco_paraformer/template.yaml
+++ b/funasr/models/seaco_paraformer/template.yaml
@@ -134,7 +134,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/funasr/models/transformer/template.yaml b/funasr/models/transformer/template.yaml
index 87814dc..eb597f9 100644
--- a/funasr/models/transformer/template.yaml
+++ b/funasr/models/transformer/template.yaml
@@ -88,7 +88,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
diff --git a/funasr/models/uniasr/template.yaml b/funasr/models/uniasr/template.yaml
index 35c6b2e..e72a2d5 100644
--- a/funasr/models/uniasr/template.yaml
+++ b/funasr/models/uniasr/template.yaml
@@ -184,7 +184,7 @@
 dataset: AudioDataset
 dataset_conf:
     index_ds: IndexDSJsonl
-    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_sampler: BatchSampler
     batch_type: example # example or length
     batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,

--
Gitblit v1.9.1