游雁
2024-02-19 1448e021accfdb03a381651cb5a8be6d1a6e8adf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import json
import torch
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
import concurrent.futures
import librosa
import torch.distributed as dist
 
 
 
def gen_jsonl_from_wav_text_list(path, data_type_list=("source", "target"), jsonl_file_out: str = None, **kwargs):
    """Merge parallel scp/text list files into one jsonl manifest.

    Each file in ``path`` holds lines of the form ``"<key> <content>"``.
    Rank 0 parses every file concurrently (one length entry per key via
    ``parse_context_length``) and writes one json object per key to
    ``jsonl_file_out``; other ranks only wait at the barrier.

    :param path: sequence of input list files, aligned with ``data_type_list``
    :param data_type_list: field name for each file (e.g. "source", "target")
    :param jsonl_file_out: output jsonl path (required on rank 0)
    :param kwargs: unused; accepted for config pass-through compatibility
    """
    # Detect distributed context explicitly instead of a bare `except:`,
    # which also swallowed KeyboardInterrupt/SystemExit.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1

    cpu_cores = os.cpu_count() or 1

    if rank == 0:
        json_dict = {}
        for data_type, data_file in zip(data_type_list, path):
            json_dict[data_type] = {}
            with open(data_file, "r") as f:
                data_file_lists = f.readlines()

            # BUG FIX: the original set task_num=1 whenever the file had
            # fewer lines than cpu_cores, but kept a chunk size of
            # ceil(len/cpu_cores)=1 — so only the first line was processed
            # and the rest were silently dropped. Partition across
            # min(cpu_cores, n_lines) workers so every line is covered.
            task_num = min(cpu_cores, len(data_file_lists)) or 1
            lines_for_each_th = (len(data_file_lists) - 1) // task_num + 1

            with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_cores) as executor:
                futures = [
                    executor.submit(
                        parse_context_length,
                        data_file_lists[i * lines_for_each_th : (i + 1) * lines_for_each_th],
                        data_type,
                    )
                    for i in range(task_num)
                ]
                for future in concurrent.futures.as_completed(futures):
                    json_dict[data_type].update(future.result())

        with open(jsonl_file_out, "w") as f:
            # Keys of the first data type drive the output; every data type
            # must contain the same keys (KeyError otherwise, as before).
            for key in json_dict[data_type_list[0]].keys():
                jsonl_line = {"key": key}
                for data_file in data_type_list:
                    jsonl_line.update(json_dict[data_file][key])
                f.write(json.dumps(jsonl_line, ensure_ascii=False) + "\n")
            # No per-line flush: the context manager flushes/closes once.

    if world_size > 1:
        dist.barrier()
    
    
def parse_context_length(data_list: list, data_type: str):
    """Compute a context length for each ``"<key> <content>"`` line.

    If the content is a path to an existing file it is loaded as 16 kHz
    audio and the length is the number of 10 ms frames (samples / 160);
    otherwise the length is the character count of the text.

    :param data_list: raw lines, each "key content..." (whitespace-split once)
    :param data_type: field-name prefix for the returned per-key dicts
    :return: ``{key: {data_type: content, f"{data_type}_len": length}}``
    """
    res = {}
    for line in data_list:
        parts = line.strip().split(maxsplit=1)
        if len(parts) < 2:
            # Skip blank or key-only lines; the original raised ValueError
            # on the tuple unpack here.
            continue
        key, content = parts
        if os.path.exists(content):
            waveform, _ = librosa.load(content, sr=16000)
            # BUG FIX: the original computed int(n // 16000 * 1000 / 10),
            # flooring to whole seconds first and discarding up to ~1 s of
            # audio. 10 ms frames at 16 kHz are simply samples // 160.
            context_len = len(waveform) // 160
        else:
            context_len = len(content)
        res[key] = {data_type: content, f"{data_type}_len": context_len}
    return res
    
 
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    """CLI entry point: convert parallel scp/text lists into a jsonl manifest.

    Example:
        python funasr/datasets/audio_datasets/scp2jsonl.py \
        ++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
        ++data_type_list='["source", "target"]' \
        ++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
    """
    # Resolve the hydra config into a plain dict so .get() defaults apply.
    options = OmegaConf.to_container(cfg, resolve=True)

    default_scp_files = (
        "/Users/zhifu/funasr1.0/test_local/wav.scp",
        "/Users/zhifu/funasr1.0/test_local/text.txt",
    )
    default_out = "/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl"

    scp_file_list = options.get("scp_file_list", default_scp_files)
    data_type_list = options.get("data_type_list", ("source", "target"))
    jsonl_file_out = options.get("jsonl_file_out", default_out)

    gen_jsonl_from_wav_text_list(
        scp_file_list,
        data_type_list=data_type_list,
        jsonl_file_out=jsonl_file_out,
    )
    
 
# Script entry point: hand control to the hydra-wrapped CLI.
if __name__ == "__main__":
    main_hydra()