From 24aea85b5bc3f354d683201fa9e37968f3f1638f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 21 三月 2024 14:01:45 +0800
Subject: [PATCH] trainer

---
 funasr/datasets/audio_datasets/scp2jsonl.py |   19 ++++++++++++-------
 1 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/funasr/datasets/audio_datasets/scp2jsonl.py b/funasr/datasets/audio_datasets/scp2jsonl.py
index e09a84a..00754dd 100644
--- a/funasr/datasets/audio_datasets/scp2jsonl.py
+++ b/funasr/datasets/audio_datasets/scp2jsonl.py
@@ -29,15 +29,19 @@
                 data_file_lists = f.readlines()
                 lines_for_each_th = (len(data_file_lists)-1)//cpu_cores + 1
                 task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
-                with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
-
-                    futures = [executor.submit(parse_context_length, data_file_lists[i*lines_for_each_th:(i+1)*lines_for_each_th], data_type) for i in range(task_num)]
+                # import pdb;pdb.set_trace()
+                if task_num > 1:
+                    with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
     
-                    for future in concurrent.futures.as_completed(futures):
-                        
-                        json_dict[data_type].update(future.result())
-            # print(json_dict)
+                        futures = [executor.submit(parse_context_length, data_file_lists[i*lines_for_each_th:(i+1)*lines_for_each_th], data_type) for i in range(task_num)]
         
+                        for future in concurrent.futures.as_completed(futures):
+                            
+                            json_dict[data_type].update(future.result())
+                else:
+                    res = parse_context_length(data_file_lists, data_type)
+                    json_dict[data_type].update(res)
+
         with open(jsonl_file_out, "w") as f:
             for key in json_dict[data_type_list[0]].keys():
                 jsonl_line = {"key": key}
@@ -46,6 +50,7 @@
                 jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
                 f.write(jsonl_line+"\n")
                 f.flush()
+        print(f"processed {len(json_dict[data_type_list[0]])} samples")
                 
     else:
         pass

--
Gitblit v1.9.1