| | |
import json
import logging
import random

import torch
import torch.distributed as dist

from funasr.register import tables
| | | |
| | | |
| | | # @tables.register("index_ds_classes", "IndexDSJsonlRankSplit") |
| | | # class IndexDSJsonlRankSplit(torch.utils.data.Dataset): |
| | | # |
| | | # def __init__(self, path): |
| | | # super().__init__() |
| | | # |
| | | # contents = [] |
| | | # with open(path, encoding='utf-8') as fin: |
| | | # for line in fin: |
| | | # data = json.loads(line.strip()) |
| | | # if "text" in data: # for sft |
| | | # self.contents.append(data['text']) |
| | | # if "source" in data: # for speech lab pretrain |
| | | # prompt = data["prompt"] |
| | | # source = data["source"] |
| | | # target = data["target"] |
| | | # source_len = data["source_len"] |
| | | # target_len = data["target_len"] |
| | | # |
| | | # contents.append({"source": source, |
| | | # "prompt": prompt, |
| | | # "target": target, |
| | | # "source_len": source_len, |
| | | # "target_len": target_len, |
| | | # } |
| | | # ) |
| | | # |
| | | # self.contents = [] |
| | | # total_num = len(contents) |
| | | # try: |
| | | # rank = dist.get_rank() |
| | | # world_size = dist.get_world_size() |
| | | # except: |
| | | # rank = 0 |
| | | # world_size = 1 |
| | | # logging.info("distributed is not initialized, only single shard") |
| | | # num_per_rank = total_num // world_size |
| | | # |
| | | # # rank = 0 |
| | | # # import ipdb; ipdb.set_trace() |
| | | # self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank] |
| | | # |
| | | # logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents))) |
| | | # |
| | | # def __len__(self): |
| | | # return len(self.contents) |
| | | # |
| | | # def __getitem__(self, index): |
| | | # try: |
| | | # data = self.contents[index] |
| | | # except: |
| | | # print(index) |
| | | # return data |
| | | # |
| | | # def get_source_len(self, data_dict): |
| | | # return data_dict["source_len"] |
| | | # |
| | | # def get_target_len(self, data_dict): |
| | | # |
| | | # return data_dict["target_len"] if "target_len" in data_dict else 0 |
| | | |
@tables.register("index_ds_classes", "IndexDSJsonl")
@tables.register("index_ds_classes", "IndexDSJsonlRankFull")
@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
class IndexDSJsonlRankFull(torch.utils.data.Dataset):
    """Index dataset backed by jsonl files; every rank loads the full list.

    ``path`` is either a single ``.jsonl``/``.json`` file or a plain-text
    file listing one jsonl path per line.  Each jsonl line is either
    ``{"text": ...}`` (sft) or a ``{"source", "target", ...}`` record
    (speech-lab pretrain).  Records whose source/target/total token lengths
    fall outside the configured limits are dropped at load time.

    NOTE(review): an older per-rank file-sharding path (``dist.get_rank()``
    / ``random.choices`` padding) was removed here; every rank now reads
    the complete file list ("RankFull").  Confirm no caller relies on the
    old ``rank_split=True`` sharding behavior.
    """

    def __init__(self, path: str, **kwargs):
        """Load, filter, and index the dataset entries.

        Args:
            path: jsonl/json file, or a list file of jsonl paths.
            **kwargs: optional limits and split controls:
                max_source_length (default 2048), min_source_length (0),
                max_target_length (2048), min_target_length (0),
                max_token_length (2200) -- cap on source_len + target_len;
                is_training (True), data_split_num (1), data_split_i (0).
        """
        super().__init__()
        self.max_source_length = kwargs.get("max_source_length", 2048)
        self.min_source_length = kwargs.get("min_source_length", 0)
        self.max_target_length = kwargs.get("max_target_length", 2048)
        self.min_target_length = kwargs.get("min_target_length", 0)
        self.max_token_length = kwargs.get("max_token_length", 2200)

        is_training = kwargs.get("is_training", True)
        if not (path.endswith(".jsonl") or path.endswith(".json")):
            # jsonl list file: optionally keep only the data_split_i-th
            # slice of the listed files; evaluation always sees the whole list.
            data_split_num = kwargs.get("data_split_num", 1)
            data_split_i = kwargs.get("data_split_i", 0)
            if not is_training:
                data_split_num = 1
                data_split_i = 0
            with open(path, encoding="utf-8") as fin:
                file_list_all = fin.readlines()

                # Ceil division so every listed file lands in exactly one slice.
                num_per_slice = (len(file_list_all) - 1) // data_split_num + 1
                file_list = file_list_all[
                    data_split_i * num_per_slice : (data_split_i + 1) * num_per_slice
                ]
                logging.info(
                    f"is_training: {is_training}, data_split_num: {data_split_num}, data_split_i: {data_split_i}, \nfile_list: {file_list}, \nfile_list_all: {file_list_all}"
                )

        else:
            file_list = [path]

        contents = []
        for file_json in file_list:
            with open(file_json.strip(), encoding="utf-8") as fin:
                for line in fin:
                    data = json.loads(line.strip())
                    if "text" in data:  # for sft
                        contents.append(data["text"])
                    if "source" in data:  # for speech lab pretrain
                        prompt = data.get("prompt", "<ASR>")
                        source = data["source"].replace(
                            "/cpfs01", "/cpfs_speech/data"
                        )  # only use in alibaba gpu group: .replace("/cpfs01", "/cpfs_speech/data")
                        target = data["target"]
                        source_len = data.get("source_len", 1)
                        target_len = data.get("target_len", 0)
                        if "aishell" in source:
                            # Strip spaces from aishell targets.
                            target = target.replace(" ", "")
                        # Length filters: skip out-of-range records entirely.
                        if (
                            source_len < self.min_source_length
                            or source_len > self.max_source_length
                        ):
                            continue
                        if (
                            target_len < self.min_target_length
                            or target_len > self.max_target_length
                        ):
                            continue
                        if (source_len + target_len) > self.max_token_length:
                            continue

                        contents_i = {
                            "source": source,
                            "prompt": prompt,
                            "target": target,
                            "source_len": source_len,
                            "target_len": target_len,
                        }
                        # Optional per-record annotations are passed through
                        # only when present in the jsonl line.
                        text_language = data.get("text_language", None)
                        if text_language is not None:
                            contents_i["text_language"] = text_language
                        if "emo_target" in data:
                            contents_i["emo_target"] = data["emo_target"]
                        if "event_target" in data:
                            contents_i["event_target"] = data["event_target"]
                        if "with_or_wo_itn" in data:
                            contents_i["with_or_wo_itn"] = data["with_or_wo_itn"]
                        # audio_language = data.get("audio_language", None)
                        # if audio_language is not None:
                        #     contents_i["audio_language"] = audio_language
                        contents.append(contents_i)

        self.contents = contents

        logging.info("total_num of samplers: {}, {}".format(len(self.contents), path))

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        # Either a raw text string (sft) or a record dict (pretrain).
        data = self.contents[index]
        return data

    def get_source_len(self, data_dict):
        # Used for length-based batching; defaults to 1 when absent.
        return data_dict.get("source_len", 1)

    def get_target_len(self, data_dict):
        # Defaults to 0 when "target_len" is absent.
        return data_dict.get("target_len", 0)