游雁
2024-04-29 03040b04e24acc7cd024a258b259841efa5adead
batch
2个文件已修改
94 ■■■■■ 已修改文件
funasr/datasets/audio_datasets/scp2jsonl.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/update_jsonl.py 93 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/scp2jsonl.py
@@ -29,7 +29,6 @@
            with open(data_file, "r") as f:
                data_file_lists = f.readlines()
                print("")
                lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
                # import pdb;pdb.set_trace()
funasr/datasets/audio_datasets/update_jsonl.py
@@ -7,31 +7,68 @@
import concurrent.futures
import librosa
import torch.distributed as dist
import threading
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
def gen_scp_from_jsonl(jsonl_file, data_type_list, wav_scp_file, text_file):
    wav_f = open(wav_scp_file, "w")
    text_f = open(text_file, "w")
def gen_scp_from_jsonl(jsonl_file, jsonl_file_out, ncpu):
    jsonl_file_out_f = open(jsonl_file_out, "w")
    with open(jsonl_file, encoding="utf-8") as fin:
        for line in fin:
            data = json.loads(line.strip())
        lines = fin.readlines()
            prompt = data.get("prompt", "<ASR>")
            source = data[data_type_list[0]]
            target = data[data_type_list[1]]
            source_len = data.get("source_len", 1)
            target_len = data.get("target_len", 0)
            if "aishell" in source:
                target = target.replace(" ", "")
            key = data["key"]
            wav_f.write(f"{key}\t{source}\n")
            wav_f.flush()
            text_f.write(f"{key}\t{target}\n")
            text_f.flush()
        num_total = len(lines)
        if ncpu > 1:
            # 使用ThreadPoolExecutor限制并发线程数
            with ThreadPoolExecutor(max_workers=ncpu) as executor:
                # 提交任务到线程池
                futures = {executor.submit(update_data, lines, i) for i in tqdm(range(num_total))}
    wav_f.close()
    text_f.close()
                # 等待所有任务完成,这会阻塞直到所有提交的任务完成
                for future in concurrent.futures.as_completed(futures):
                    # 这里可以添加额外的逻辑来处理完成的任务,但在这个例子中我们只是等待
                    pass
        else:
            for i in range(num_total):
                update_data(lines, i)
        print("All audio durations have been processed.")
        for line in lines:
            jsonl_file_out_f.write(line)
            jsonl_file_out_f.flush()
    jsonl_file_out_f.close()
def update_data(lines, i):
    line = lines[i]
    data = json.loads(line.strip())
    wav_path = data["source"].replace("/cpfs01", "/cpfs_speech/data")
    waveform, _ = librosa.load(wav_path, sr=16000)
    sample_num = len(waveform)
    source_len = int(sample_num / 16000 * 1000 / 10)
    source_len_old = data["source_len"]
    if source_len_old != source_len:
        print(f"wav: {wav_path}, old: {source_len_old}, new: {source_len}")
    data["source_len"] = source_len
    jsonl_line = json.dumps(data, ensure_ascii=False)
    lines[i] = jsonl_line
def update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu=1):
    os.makedirs(jsonl_file_out_dir, exist_ok=True)
    with open(jsonl_file_list_in, "r") as f:
        data_file_lists = f.readlines()
        for i, jsonl in enumerate(data_file_lists):
            filename_with_extension = os.path.basename(jsonl.strip())
            jsonl_file_out = os.path.join(jsonl_file_out_dir, filename_with_extension)
            print(f"{i}/{len(data_file_lists)}, jsonl: {jsonl}, {jsonl_file_out}")
            gen_scp_from_jsonl(jsonl.strip(), jsonl_file_out, ncpu)
@hydra.main(config_name=None, version_base=None)
@@ -40,17 +77,13 @@
    kwargs = OmegaConf.to_container(cfg, resolve=True)
    print(kwargs)
    scp_file_list = kwargs.get(
        "scp_file_list",
        ("/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"),
    jsonl_file_list_in = kwargs.get(
        "jsonl_file_list_in", "/Users/zhifu/funasr1.0/data/list/data_jsonl.list"
    )
    if isinstance(scp_file_list, str):
        scp_file_list = eval(scp_file_list)
    data_type_list = kwargs.get("data_type_list", ("source", "target"))
    jsonl_file = kwargs.get(
        "jsonl_file_in", "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"
    )
    gen_scp_from_jsonl(jsonl_file, data_type_list, *scp_file_list)
    jsonl_file_out_dir = kwargs.get("jsonl_file_out_dir", "/Users/zhifu/funasr1.0/data_tmp")
    ncpu = kwargs.get("ncpu", 1)
    update_wav_len(jsonl_file_list_in, jsonl_file_out_dir, ncpu)
    # gen_scp_from_jsonl(jsonl_file_list_in, jsonl_file_out_dir)
"""