kongdeqiang
5 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import json
import torch
import logging
 
import librosa
import random
import torch.distributed as dist
 
from funasr.register import tables
 
 
@tables.register("index_ds_classes", "OpenAIIndexDSJsonl")
class OpenAIIndexDSJsonl(torch.utils.data.Dataset):  # torch.utils.data.Dataset
    """In-memory index over jsonl files holding chat-style ``messages`` records.

    Every non-blank line of a jsonl file is parsed as a JSON object with a
    ``"messages"`` list of ``{"role": ..., "content": ...}`` items and optional
    ``"speech_length"`` / ``"text_length"`` fields. Records whose lengths exceed
    the configured limits are dropped; the rest are stored in ``self.contents``
    as dicts with the keys ``"system"``, ``"user"``, ``"assistant"`` and
    ``"source_len"``.

    Args:
        path: Either a single ``.jsonl`` / ``.json`` file, or a text file that
            lists one jsonl path per line.
        **kwargs: Optional settings read by this class:
            ``max_source_length`` (default 3000), ``min_source_length`` (0),
            ``max_target_length`` (2048), ``min_target_length`` (0),
            ``max_token_length`` (2200), ``is_training`` (True),
            ``data_split_num`` (1) and ``data_split_i`` (0). The two split
            options are only used when ``path`` is a list file and
            ``is_training`` is true.
    """

    def __init__(self, path: str, **kwargs):
        super().__init__()

        self.max_source_length = kwargs.get("max_source_length", 3000)
        self.min_source_length = kwargs.get("min_source_length", 0)
        self.max_target_length = kwargs.get("max_target_length", 2048)
        self.min_target_length = kwargs.get("min_target_length", 0)
        self.max_token_length = kwargs.get("max_token_length", 2200)

        is_training = kwargs.get("is_training", True)
        if not path.endswith((".jsonl", ".json")):
            # jsonl list file: each line names one jsonl file
            data_split_num = kwargs.get("data_split_num", 1)
            data_split_i = kwargs.get("data_split_i", 0)

            if not is_training:
                # Outside of training the whole list is always used.
                data_split_num = 1
                data_split_i = 0
            with open(path, encoding="utf-8") as fin:
                file_list_all = fin.readlines()

            # Ceil division so that the slices together cover every listed file.
            num_per_slice = (len(file_list_all) - 1) // data_split_num + 1
            file_list = file_list_all[
                data_split_i * num_per_slice : (data_split_i + 1) * num_per_slice
            ]
            logging.info(
                f"is_training: {is_training}, data_split_num: {data_split_num}, data_split_i: {data_split_i}, \nfile_list: {file_list}, \nfile_list_all: {file_list_all}"
            )

        else:
            file_list = [path]

        contents = []
        for file_json in file_list:
            file_json = file_json.strip()
            if not file_json:
                # Blank entry in the list file (e.g. a trailing empty line).
                continue
            with open(file_json, encoding="utf-8") as fin:
                for line in fin:
                    line = line.strip()
                    if not line:
                        # Blank lines carry no record and would break json.loads.
                        continue
                    data_dict = json.loads(line)
                    data = data_dict["messages"]
                    # The stored speech length is divided by 8 before filtering;
                    # a missing field yields -1 (-1 // 8 == -1).
                    speech_length = data_dict.get("speech_length", -1) // 8
                    text_length = data_dict.get("text_length", 0)
                    if speech_length > self.max_source_length:
                        logging.info(
                            f"speech_length: {speech_length} > {self.max_source_length}, drop it"
                        )
                        continue
                    if text_length > self.max_target_length:
                        continue

                    system, user, assistant = [], [], []
                    for item in data:
                        role = item["role"]
                        content = item["content"]
                        if role == "system":
                            system.append(content)
                        elif role == "user":
                            user.append(content)
                        elif role == "assistant":
                            assistant.append(content)

                    # Repeat the system prompt(s) once per user turn.
                    system = system * len(user)

                    contents_i = {
                        "system": system,
                        "user": user,
                        "assistant": assistant,
                        "source_len": speech_length + text_length,
                    }
                    contents.append(contents_i)

        self.contents = contents

        logging.info("total_num of samplers: {}, {}".format(len(self.contents), path))

    def __len__(self):
        """Return the number of records kept after filtering."""
        return len(self.contents)

    def __getitem__(self, index):
        """Return the stored record dict at ``index``."""
        return self.contents[index]

    def get_source_len(self, data_dict):
        """Return the source length of a record.

        Uses the record's ``"source_len"`` value; when it is missing or
        negative, falls back to the number of system entries plus the number
        of user entries.
        """
        source_len = data_dict.get("source_len", -1)
        if source_len < 0:
            source_len = len(data_dict["system"]) + len(data_dict["user"])
        return source_len

    def get_target_len(self, data_dict):
        """Return the target length of a record; this index always reports 0."""
        return 0
 
 
if __name__ == "__main__":
    # Ad-hoc smoke run: index one local jsonl file and dump the parsed records.
    demo_path = "/Users/zhifu/funasr1.0/test_local/data_tmp/tmp_wav_10.jsonl"
    demo_index = OpenAIIndexDSJsonl(path=demo_path)
    print(demo_index.contents)