From fdafd3f6bc2f04d16e7cab5afcdb1257e87a8a78 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 17 十二月 2024 11:15:53 +0800
Subject: [PATCH] emotion2vec

---
 funasr/datasets/audio_datasets/scp2jsonl.py |   27 +++++++++++++++++++--------
 1 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/funasr/datasets/audio_datasets/scp2jsonl.py b/funasr/datasets/audio_datasets/scp2jsonl.py
index f6ceb69..48c64d2 100644
--- a/funasr/datasets/audio_datasets/scp2jsonl.py
+++ b/funasr/datasets/audio_datasets/scp2jsonl.py
@@ -7,6 +7,7 @@
 import concurrent.futures
 import librosa
 import torch.distributed as dist
+from tqdm import tqdm
 
 
 def gen_jsonl_from_wav_text_list(
@@ -41,6 +42,7 @@
                                     i * lines_for_each_th : (i + 1) * lines_for_each_th
                                 ],
                                 data_type,
+                                i,
                             )
                             for i in range(task_num)
                         ]
@@ -56,7 +58,8 @@
             for key in json_dict[data_type_list[0]].keys():
                 jsonl_line = {"key": key}
                 for data_file in data_type_list:
-                    jsonl_line.update(json_dict[data_file][key])
+                    if key in json_dict[data_file]:
+                        jsonl_line.update(json_dict[data_file][key])
                 jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
                 f.write(jsonl_line + "\n")
                 f.flush()
@@ -69,16 +72,24 @@
         dist.barrier()
 
 
-def parse_context_length(data_list: list, data_type: str):
-
+def parse_context_length(data_list: list, data_type: str, id=0):
+    pbar = tqdm(total=len(data_list), dynamic_ncols=True)
     res = {}
     for i, line in enumerate(data_list):
-        key, line = line.strip().split(maxsplit=1)
+        pbar.update(1)
+        pbar.set_description(f"cpu: {id}")
+        lines = line.strip().split(maxsplit=1)
+        key = lines[0]
+        line = lines[1] if len(lines) > 1 else ""
         line = line.strip()
-        if os.path.exists(line):
-            waveform, _ = librosa.load(line, sr=16000)
-            sample_num = len(waveform)
-            context_len = int(sample_num / 16000 * 1000 / 10)
+        if data_type == "source":
+            if os.path.exists(line):
+                waveform, _ = librosa.load(line, sr=16000)
+                sample_num = len(waveform)
+                context_len = int(sample_num * 1000 / 16000 / 10)
+            else:
+                print("source file not found: {}".format(line))
+                continue
         else:
             context_len = len(line.split()) if " " in line else len(line)
         res[key] = {data_type: line, f"{data_type}_len": context_len}

--
Gitblit v1.9.1