游雁
2024-06-19 45d7aa9004763684fb748ee17942ecba81042201
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import json
import torch
import logging
 
import librosa
import random
import torch.distributed as dist
 
from funasr.register import tables
 
 
@tables.register("index_ds_classes", "OpenAIIndexDSJsonl")
class OpenAIIndexDSJsonl(torch.utils.data.Dataset):
    """Index dataset over OpenAI-style chat ``jsonl`` files.

    Each line of a data file is a JSON object with a ``"messages"`` list of
    ``{"role": ..., "content": ...}`` turns (roles: system / user / assistant)
    and optional ``"speech_length"`` / ``"text_length"`` fields used to
    estimate a per-sample source length for bucketing/sorting.

    ``path`` is either a single ``.json``/``.jsonl`` data file, or a plain
    text file listing one data-file path per line; in the latter case the
    list may be sharded across workers via ``data_split_num``/``data_split_i``.
    """

    def __init__(self, path: str, **kwargs):
        """Load and parse all samples referenced by ``path``.

        Args:
            path: A ``.json``/``.jsonl`` data file, or a list file of such paths.
            **kwargs: Optional keys:
                is_training (bool, default True): when False, sharding is
                    disabled and the full list file is read.
                data_split_num (int, default 1): number of shards of the list file.
                data_split_i (int, default 0): index of the shard to load.
        """
        super().__init__()

        is_training = kwargs.get("is_training", True)
        if not (path.endswith(".jsonl") or path.endswith(".json")):
            # `path` is a list file: one jsonl path per line, possibly sharded.
            file_list = self._resolve_file_list(path, is_training, **kwargs)
        else:
            file_list = [path]

        contents = []
        for file_json in file_list:
            contents.extend(self._parse_file(file_json.strip()))
        self.contents = contents

        logging.info("total_num of samplers: {}, {}".format(len(self.contents), path))

    @staticmethod
    def _resolve_file_list(path, is_training, **kwargs):
        """Read the list file at ``path`` and return this worker's shard of lines."""
        data_split_num = kwargs.get("data_split_num", 1)
        data_split_i = kwargs.get("data_split_i", 0)

        if not is_training:
            # Evaluation always sees the complete file list.
            data_split_num = 1
            data_split_i = 0
        with open(path, encoding="utf-8") as fin:
            file_list_all = fin.readlines()

            # Ceiling division so every file is assigned to exactly one shard.
            num_per_slice = (len(file_list_all) - 1) // data_split_num + 1  # 16
            file_list = file_list_all[
                data_split_i * num_per_slice : (data_split_i + 1) * num_per_slice
            ]
            logging.info(
                f"is_training: {is_training}, data_split_num: {data_split_num}, data_split_i: {data_split_i}, \nfile_list: {file_list}, \nfile_list_all: {file_list_all}"
            )
        return file_list

    @staticmethod
    def _parse_file(file_path):
        """Parse one jsonl data file into a list of sample dicts."""
        samples = []
        with open(file_path, encoding="utf-8") as fin:
            for line in fin:
                data_dict = json.loads(line.strip())
                data = data_dict["messages"]
                # NOTE: a missing "speech_length" yields -1 // 8 == -1 (floor
                # division), preserving the "unknown" sentinel.
                # The // 8 presumably converts frames to a token-like unit —
                # TODO confirm against the tokenizer/frontend downsampling rate.
                speech_length = data_dict.get("speech_length", -1) // 8
                text_length = data_dict.get("text_length", 0)

                system, user, assistant = [], [], []
                for item in data:
                    role = item["role"]
                    content = item["content"]
                    if role == "system":
                        system.append(content)
                    elif role == "user":
                        user.append(content)
                    elif role == "assistant":
                        assistant.append(content)

                # Repeat the system prompt once per user turn so the three
                # lists stay index-aligned for multi-turn samples.
                system = system * len(user)

                samples.append(
                    {
                        "system": system,
                        "user": user,
                        "assistant": assistant,
                        "source_len": speech_length + text_length,
                    }
                )
        return samples

    def __len__(self):
        """Return the number of parsed samples."""
        return len(self.contents)

    def __getitem__(self, index):
        """Return the raw sample dict at ``index``."""
        return self.contents[index]

    def get_source_len(self, data_dict):
        """Return the precomputed source length, falling back to turn counts.

        The fallback (len(system) + len(user)) counts conversation turns, not
        tokens — a much coarser estimate, used only when "source_len" < 0.
        """
        source_len = data_dict.get("source_len", -1)
        if source_len < 0:
            source_len = len(data_dict["system"]) + len(data_dict["user"])
        return source_len

    def get_target_len(self, data_dict):
        """Target length is not tracked for this dataset; always 0."""
        return 0
 
 
if __name__ == "__main__":
    # Ad-hoc smoke test: load a local sample file and dump the parsed samples.
    dataset = OpenAIIndexDSJsonl(
        path="/Users/zhifu/funasr1.0/test_local/data_tmp/tmp_wav_10.jsonl"
    )
    print(dataset.contents)