| | |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @tables.register("index_ds_classes", "IndexDSJsonl") |
| | | class IndexDSJsonl(torch.utils.data.Dataset): |
| | | @tables.register("index_ds_classes", "IndexDSJsonlRankSplit") |
| | | class IndexDSJsonlRankSplit(torch.utils.data.Dataset): |
| | | |
| | | def __init__(self, path): |
| | | super().__init__() |
| | |
def get_target_len(self, data_dict):
    """Return the ``"target_len"`` entry of *data_dict*, or 0 if it is absent."""
    # Guard clause: records without a target length count as zero.
    if "target_len" not in data_dict:
        return 0
    return data_dict["target_len"]
| | | |
@tables.register("index_ds_classes", "IndexDSJsonl")
@tables.register("index_ds_classes", "IndexDSJsonlRankFull")
class IndexDSJsonlRankFull(torch.utils.data.Dataset):
    """Index dataset that loads every record of a JSONL file into memory.

    Each non-blank line of the file is parsed as one JSON object. Objects
    with a ``"text"`` key contribute that value as-is; objects with a
    ``"source"`` key contribute a dict with the keys ``source``, ``prompt``,
    ``target``, ``source_len`` and ``target_len``. The class is registered
    under both ``"IndexDSJsonl"`` and ``"IndexDSJsonlRankFull"``.
    """

    def __init__(self, path):
        """Read the JSONL file at *path* and build the in-memory index.

        Args:
            path: Path of a UTF-8 encoded JSONL file.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If a record has ``"source"`` but no ``"target"``.
        """
        super().__init__()

        contents = []
        with open(path, encoding='utf-8') as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    # Blank lines carry no record; json.loads("") would raise.
                    continue
                data = json.loads(line)
                if "text" in data:  # for sft
                    # Collect into the local list: self.contents is only
                    # assigned once the whole file has been read.
                    contents.append(data['text'])
                if "source" in data:  # for speech lab pretrain
                    contents.append({
                        "source": data["source"],
                        "prompt": data.get("prompt", "<ASR>"),
                        "target": data["target"],
                        "source_len": data.get("source_len", 1),
                        "target_len": data.get("target_len", 0),
                    })

        self.contents = contents

        logging.info(
            "total_num of samplers across ranks: {}".format(len(self.contents)))

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.contents)

    def __getitem__(self, index):
        """Return the record stored at *index*.

        An out-of-range index raises ``IndexError`` from the underlying list,
        which is what sequence-style iteration relies on to terminate.
        """
        return self.contents[index]

    def get_source_len(self, data_dict):
        """Return ``data_dict["source_len"]``, defaulting to 1 when absent."""
        return data_dict.get("source_len", 1)

    def get_target_len(self, data_dict):
        """Return ``data_dict["target_len"]``, defaulting to 0 when absent."""
        return data_dict.get("target_len", 0)