From 03040b04e24acc7cd024a258b259841efa5adead Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 29 四月 2024 19:14:13 +0800
Subject: [PATCH] batch

---
 funasr/datasets/audio_datasets/update_jsonl.py |   93 +++++++++++++++++++++++++++++++---------------
 funasr/datasets/audio_datasets/scp2jsonl.py    |    1 
 2 files changed, 63 insertions(+), 31 deletions(-)

diff --git a/funasr/datasets/audio_datasets/scp2jsonl.py b/funasr/datasets/audio_datasets/scp2jsonl.py
index f167173..f4c9d74 100644
--- a/funasr/datasets/audio_datasets/scp2jsonl.py
+++ b/funasr/datasets/audio_datasets/scp2jsonl.py
@@ -29,7 +29,6 @@
             with open(data_file, "r") as f:
 
                 data_file_lists = f.readlines()
-                print("")
                 lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
                 task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
                 # import pdb;pdb.set_trace()
diff --git a/funasr/datasets/audio_datasets/update_jsonl.py b/funasr/datasets/audio_datasets/update_jsonl.py
index 3ce96ca..6fe377c 100644
--- a/funasr/datasets/audio_datasets/update_jsonl.py
+++ b/funasr/datasets/audio_datasets/update_jsonl.py
@@ -7,31 +7,68 @@
 import concurrent.futures
 import librosa
 import torch.distributed as dist
+import threading
+from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor
 
 
-def gen_scp_from_jsonl(jsonl_file, data_type_list, wav_scp_file, text_file):
-
-    wav_f = open(wav_scp_file, "w")
-    text_f = open(text_file, "w")
+def gen_scp_from_jsonl(jsonl_file, jsonl_file_out, ncpu):
+    jsonl_file_out_f = open(jsonl_file_out, "w")
     with open(jsonl_file, encoding="utf-8") as fin:
-        for line in fin:
-            data = json.loads(line.strip())
+        lines = fin.readlines()
 
-            prompt = data.get("prompt", "<ASR>")
-            source = data[data_type_list[0]]
-            target = data[data_type_list[1]]
-            source_len = data.get("source_len", 1)
-            target_len = data.get("target_len", 0)
-            if "aishell" in source:
-                target = target.replace(" ", "")
-            key = data["key"]
-            wav_f.write(f"{key}\t{source}\n")
-            wav_f.flush()
-            text_f.write(f"{key}\t{target}\n")
-            text_f.flush()
+        num_total = len(lines)
+        if ncpu > 1:
+            # 浣跨敤ThreadPoolExecutor闄愬埗骞跺彂绾跨▼鏁�
+            with ThreadPoolExecutor(max_workers=ncpu) as executor:
+                # 鎻愪氦浠诲姟鍒扮嚎绋嬫睜
+                futures = {executor.submit(update_data, lines, i) for i in tqdm(range(num_total))}
 
-    wav_f.close()
-    text_f.close()
+                # 绛夊緟鎵�鏈変换鍔″畬鎴愶紝杩欎細闃诲鐩村埌鎵�鏈夋彁浜ょ殑浠诲姟瀹屾垚
+                for future in concurrent.futures.as_completed(futures):
+                    # 杩欓噷鍙互娣诲姞棰濆鐨勯�昏緫鏉ュ鐞嗗畬鎴愮殑浠诲姟锛屼絾鍦ㄨ繖涓緥瀛愪腑鎴戜滑鍙槸绛夊緟
+                    pass
+        else:
+            for i in range(num_total):
+                update_data(lines, i)
+        print("All audio durations have been processed.")
+
+        for line in lines:
+
+            jsonl_file_out_f.write(line)
+            jsonl_file_out_f.flush()
+
+    jsonl_file_out_f.close()
+
+
+def update_data(lines, i):
+    line = lines[i]
+    data = json.loads(line.strip())
+
+    wav_path = data["source"].replace("/cpfs01", "/cpfs_speech/data")
+    waveform, _ = librosa.load(wav_path, sr=16000)
+    sample_num = len(waveform)
+    source_len = int(sample_num / 16000 * 1000 / 10)
+    source_len_old = data["source_len"]
+    if source_len_old != source_len:
+        print(f"wav: {wav_path}, old: {source_len_old}, new: {source_len}")
+    data["source_len"] = source_len
+    jsonl_line = json.dumps(data, ensure_ascii=False)
+    lines[i] = jsonl_line
+
+
+def update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu=1):
+
+    os.makedirs(jsonl_file_out_dir, exist_ok=True)
+    with open(jsonl_file_list_in, "r") as f:
+        data_file_lists = f.readlines()
+
+        for i, jsonl in enumerate(data_file_lists):
+            filename_with_extension = os.path.basename(jsonl.strip())
+            jsonl_file_out = os.path.join(jsonl_file_out_dir, filename_with_extension)
+            print(f"{i}/{len(data_file_lists)}, jsonl: {jsonl}, {jsonl_file_out}")
+
+            gen_scp_from_jsonl(jsonl.strip(), jsonl_file_out, ncpu)
 
 
 @hydra.main(config_name=None, version_base=None)
@@ -40,17 +77,13 @@
     kwargs = OmegaConf.to_container(cfg, resolve=True)
     print(kwargs)
 
-    scp_file_list = kwargs.get(
-        "scp_file_list",
-        ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"),
+    jsonl_file_list_in = kwargs.get(
+        "jsonl_file_list_in", "/Users/zhifu/funasr1.0/data/list/data_jsonl.list"
     )
-    if isinstance(scp_file_list, str):
-        scp_file_list = eval(scp_file_list)
-    data_type_list = kwargs.get("data_type_list", ("source", "target"))
-    jsonl_file = kwargs.get(
-        "jsonl_file_in", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
-    )
-    gen_scp_from_jsonl(jsonl_file, data_type_list, *scp_file_list)
+    jsonl_file_out_dir = kwargs.get("jsonl_file_out_dir", "/Users/zhifu/funasr1.0/data_tmp")
+    ncpu = kwargs.get("ncpu", 1)
+    update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu)
+    # gen_scp_from_jsonl(jsonl_file_list_in, jsonl_file_out_dir)
 
 
 """

--
Gitblit v1.9.1