From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' at line 514 to avoid a naming conflict with line 365
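
The patch also guards the thread pool behind a task_num check and falls
back to a single-threaded parse for small inputs, adds a per-worker tqdm
progress bar, and prints the number of processed samples once the jsonl
file is written. Each scp line is expected to hold a key and a value
separated by whitespace (e.g. "key1 /path/to/key1.wav" is an illustrative
line, not a real path); lines that carry only a key now map to an empty
value instead of raising.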

---
 funasr/datasets/audio_datasets/scp2jsonl.py |   94 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------------------
 1 file changed, 62 insertions(+), 32 deletions(-)

diff --git a/funasr/datasets/audio_datasets/scp2jsonl.py b/funasr/datasets/audio_datasets/scp2jsonl.py
index e09a84a..f4c9d74 100644
--- a/funasr/datasets/audio_datasets/scp2jsonl.py
+++ b/funasr/datasets/audio_datasets/scp2jsonl.py
@@ -7,10 +7,12 @@
 import concurrent.futures
 import librosa
 import torch.distributed as dist
+from tqdm import tqdm
 
 
-
-def gen_jsonl_from_wav_text_list(path, data_type_list=("source", "target"), jsonl_file_out:str=None, **kwargs):
+def gen_jsonl_from_wav_text_list(
+    path, data_type_list=("source", "target"), jsonl_file_out: str = None, **kwargs
+):
     try:
         rank = dist.get_rank()
         world_size = dist.get_world_size()
@@ -25,63 +27,93 @@
         for data_type, data_file in zip(data_type_list, path):
             json_dict[data_type] = {}
             with open(data_file, "r") as f:
-                
-                data_file_lists = f.readlines()
-                lines_for_each_th = (len(data_file_lists)-1)//cpu_cores + 1
-                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
-                with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
 
-                    futures = [executor.submit(parse_context_length, data_file_lists[i*lines_for_each_th:(i+1)*lines_for_each_th], data_type) for i in range(task_num)]
-    
-                    for future in concurrent.futures.as_completed(futures):
-                        
-                        json_dict[data_type].update(future.result())
-            # print(json_dict)
-        
+                data_file_lists = f.readlines()
+                lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
+                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
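+                # split the lines evenly across workers (ceil division); parallelize only when there are more lines than CPU cores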
+                if task_num > 1:
+                    with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
+
+                        futures = [
+                            executor.submit(
+                                parse_context_length,
+                                data_file_lists[
+                                    i * lines_for_each_th : (i + 1) * lines_for_each_th
+                                ],
+                                data_type,
+                                i,
+                            )
+                            for i in range(task_num)
+                        ]
+
+                        for future in concurrent.futures.as_completed(futures):
+
+                            json_dict[data_type].update(future.result())
+                else:
+                    res = parse_context_length(data_file_lists, data_type)
+                    json_dict[data_type].update(res)
+
         with open(jsonl_file_out, "w") as f:
             for key in json_dict[data_type_list[0]].keys():
                 jsonl_line = {"key": key}
                 for data_file in data_type_list:
                     jsonl_line.update(json_dict[data_file][key])
                 jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
-                f.write(jsonl_line+"\n")
+                f.write(jsonl_line + "\n")
                 f.flush()
-                
+        print(f"processed {len(json_dict[data_type_list[0]])} samples")
+
     else:
         pass
-        
+
     if world_size > 1:
         dist.barrier()
-    
-    
-def parse_context_length(data_list: list, data_type: str):
-    
+
+
+def parse_context_length(data_list: list, data_type: str, task_id=0):
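+    # one progress bar per worker; task_id distinguishes the bars of parallel workers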
+    pbar = tqdm(total=len(data_list), dynamic_ncols=True)
     res = {}
     for i, line in enumerate(data_list):
-        key, line = line.strip().split(maxsplit=1)
+        pbar.update(1)
+        pbar.set_description(f"cpu: {task_id}")
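+        # an scp line may carry only a key and no value; fall back to an empty string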
+        lines = line.strip().split(maxsplit=1)
+        key = lines[0]
+        line = lines[1] if len(lines) > 1 else ""
         line = line.strip()
         if os.path.exists(line):
             waveform, _ = librosa.load(line, sr=16000)
             sample_num = len(waveform)
-            context_len = int(sample_num//16000*1000/10)
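+            # audio length in 10 ms frames: samples / 16000 Hz -> seconds, * 1000 / 10 -> frames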
+            context_len = int(sample_num / 16000 * 1000 / 10)
         else:
             context_len = len(line.split()) if " " in line else len(line)
         res[key] = {data_type: line, f"{data_type}_len": context_len}
     return res
-    
+
 
 @hydra.main(config_name=None, version_base=None)
 def main_hydra(cfg: DictConfig):
- 
-    kwargs = OmegaConf.to_container(cfg, resolve=True)
 
-    scp_file_list = kwargs.get("scp_file_list", ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"))
+    kwargs = OmegaConf.to_container(cfg, resolve=True)
+    print(kwargs)
+
+    scp_file_list = kwargs.get(
+        "scp_file_list",
+        ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"),
+    )
     if isinstance(scp_file_list, str):
         scp_file_list = eval(scp_file_list)
     data_type_list = kwargs.get("data_type_list", ("source", "target"))
-    jsonl_file_out = kwargs.get("jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl")
-    gen_jsonl_from_wav_text_list(scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out)
-    
+    jsonl_file_out = kwargs.get(
+        "jsonl_file_out", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
+    )
+    gen_jsonl_from_wav_text_list(
+        scp_file_list, data_type_list=data_type_list, jsonl_file_out=jsonl_file_out
+    )
+
 
 """
 python -m funasr.datasets.audio_datasets.scp2jsonl \
@@ -92,5 +124,3 @@
 
 if __name__ == "__main__":
     main_hydra()
-
-    
\ No newline at end of file

--
Gitblit v1.9.1