From dd927baf28266c47acda6ae8d72b206526676201 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 29 四月 2024 17:10:29 +0800
Subject: [PATCH] batch
---
funasr/datasets/audio_datasets/update_jsonl.py | 64 ++++++++++++++++
funasr/datasets/audio_datasets/scp2len.py | 121 ++++++++++++++++++++++++++++++
2 files changed, 185 insertions(+), 0 deletions(-)
diff --git a/funasr/datasets/audio_datasets/scp2len.py b/funasr/datasets/audio_datasets/scp2len.py
new file mode 100644
index 0000000..5d742b1
--- /dev/null
+++ b/funasr/datasets/audio_datasets/scp2len.py
@@ -0,0 +1,121 @@
+import os
+import json
+import torch
+import logging
+import hydra
+from omegaconf import DictConfig, OmegaConf
+import concurrent.futures
+import librosa
+import torch.distributed as dist
+from tqdm import tqdm
+
+
+def gen_jsonl_from_wav_text_list(
+ path, data_type_list=("source",), jsonl_file_out: str = None, **kwargs
+):
+ try:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ except:
+ rank = 0
+ world_size = 1
+
+ cpu_cores = os.cpu_count() or 1
+ print(f"convert wav.scp text to jsonl, ncpu: {cpu_cores}")
+ if rank == 0:
+ json_dict = {}
+ # for data_type, data_file in zip(data_type_list, path):
+ data_type = data_type_list[0]
+ data_file = path
+ json_dict[data_type] = {}
+ with open(data_file, "r") as f:
+
+ data_file_lists = f.readlines()
+ print("")
+ lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
+ task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
+ # import pdb;pdb.set_trace()
+ if task_num > 1:
+ with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
+
+ futures = [
+ executor.submit(
+ parse_context_length,
+ data_file_lists[i * lines_for_each_th : (i + 1) * lines_for_each_th],
+ data_type,
+ i,
+ )
+ for i in range(task_num)
+ ]
+
+ for future in concurrent.futures.as_completed(futures):
+
+ json_dict[data_type].update(future.result())
+ else:
+ res = parse_context_length(data_file_lists, data_type)
+ json_dict[data_type].update(res)
+
+ with open(jsonl_file_out, "w") as f:
+ for key in json_dict[data_type_list[0]].keys():
+ jsonl_line = {"key": key}
+ for data_file in data_type_list:
+ jsonl_line.update(json_dict[data_file][key])
+ # jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
+ source_len = jsonl_line["source_len"]
+ jsonl_line = f"{key} {source_len}"
+ f.write(jsonl_line + "\n")
+ f.flush()
+ print(f"processed {len(json_dict[data_type_list[0]])} samples")
+
+ else:
+ pass
+
+ if world_size > 1:
+ dist.barrier()
+
+
+def parse_context_length(data_list: list, data_type: str, id=0):
+ pbar = tqdm(total=len(data_list), dynamic_ncols=True)
+ res = {}
+ for i, line in enumerate(data_list):
+ pbar.update(1)
+ pbar.set_description(f"cpu: {id}")
+ lines = line.strip().split(maxsplit=1)
+ key = lines[0]
+ line = lines[1] if len(lines) > 1 else ""
+ line = line.strip()
+ if os.path.exists(line):
+ waveform, _ = librosa.load(line, sr=16000)
+ sample_num = len(waveform)
+ context_len = int(sample_num / 16000 * 1000 / 10)
+ else:
+ context_len = len(line.split()) if " " in line else len(line)
+ res[key] = {data_type: line, f"{data_type}_len": context_len}
+ return res
+
+
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(cfg: DictConfig):
+
+ kwargs = OmegaConf.to_container(cfg, resolve=True)
+ print(kwargs)
+
+ scp_file_list = kwargs.get("scp_file_list", "/Users/zhifu/funasr1.0/data/list/train_wav.scp")
+ # if isinstance(scp_file_list, str):
+ # scp_file_list = eval(scp_file_list)
+ data_type_list = kwargs.get("data_type_list", ("source",))
+ jsonl_file_out = kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/data/list/wav_len.txt")
+ gen_jsonl_from_wav_text_list(
+ scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out
+ )
+
+
+"""
+python -m funasr.datasets.audio_datasets.scp2jsonl \
+++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
+++data_type_list='["source", "target"]' \
+++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
+"""
+
+if __name__ == "__main__":
+ main_hydra()
diff --git a/funasr/datasets/audio_datasets/update_jsonl.py b/funasr/datasets/audio_datasets/update_jsonl.py
new file mode 100644
index 0000000..3ce96ca
--- /dev/null
+++ b/funasr/datasets/audio_datasets/update_jsonl.py
@@ -0,0 +1,64 @@
+import os
+import json
+import torch
+import logging
+import hydra
+from omegaconf import DictConfig, OmegaConf
+import concurrent.futures
+import librosa
+import torch.distributed as dist
+
+
+def gen_scp_from_jsonl(jsonl_file, data_type_list, wav_scp_file, text_file):
+
+ wav_f = open(wav_scp_file, "w")
+ text_f = open(text_file, "w")
+ with open(jsonl_file, encoding="utf-8") as fin:
+ for line in fin:
+ data = json.loads(line.strip())
+
+ prompt = data.get("prompt", "<ASR>")
+ source = data[data_type_list[0]]
+ target = data[data_type_list[1]]
+ source_len = data.get("source_len", 1)
+ target_len = data.get("target_len", 0)
+ if "aishell" in source:
+ target = target.replace(" ", "")
+ key = data["key"]
+ wav_f.write(f"{key}\t{source}\n")
+ wav_f.flush()
+ text_f.write(f"{key}\t{target}\n")
+ text_f.flush()
+
+ wav_f.close()
+ text_f.close()
+
+
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(cfg: DictConfig):
+
+ kwargs = OmegaConf.to_container(cfg, resolve=True)
+ print(kwargs)
+
+ scp_file_list = kwargs.get(
+ "scp_file_list",
+ ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"),
+ )
+ if isinstance(scp_file_list, str):
+ scp_file_list = eval(scp_file_list)
+ data_type_list = kwargs.get("data_type_list", ("source", "target"))
+ jsonl_file = kwargs.get(
+ "jsonl_file_in", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
+ )
+ gen_scp_from_jsonl(jsonl_file, data_type_list, *scp_file_list)
+
+
+"""
+python -m funasr.datasets.audio_datasets.json2scp \
+++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
+++data_type_list='["source", "target"]' \
+++jsonl_file_in=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
+"""
+
+if __name__ == "__main__":
+ main_hydra()
--
Gitblit v1.9.1