游雁
2024-06-17 0033151b6286e9d17a9b91567ba649ff14a89464
funasr/datasets/openai_datasets/datasets.py
@@ -300,8 +300,9 @@
         return len(self.index_ds)
     def __getitem__(self, index):
         # import pdb;
         # pdb.set_trace()
+        import pdb
+        pdb.set_trace()
         output = None
@@ -318,7 +319,15 @@
             user = item["user"]
             assistant = item["assistant"]
-            input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg = [], [], [], [], [], []
+            input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
+                [],
+                [],
+                [],
+                [],
+                [],
+                [],
+                [],
+            )
             for i, (system_prompt, user_prompt, target_out) in enumerate(
                 zip(system, user, assistant)
@@ -336,7 +345,8 @@
                 source_ids = []
                 fbank_i = []
                 fbank_mask_i = []
-                fbank_beg_i = []
+                fake_token_len_i = 0
+                fbank_beg_i = -1
                 fbank_lens_i = []
                 for k, sub_str in enumerate(splits):
                     if not sub_str.startswith("<|startofspeech|>"):
@@ -369,14 +379,17 @@
                             olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
                             olens = 1 + (olens - 3 + 2 * 1) // 2
-                            sub_token_len = (olens - 1) // 2 + 1
-                            sub_token = [0] * sub_token_len
-                            fbank_beg_i = [len(source_ids)]
-                            source_ids += sub_token
-                            fbank_mask_i += [1] * len(sub_token)
+                            fake_token_len_i = (olens - 1) // 2 + 1
+                            fake_token = [0] * fake_token_len_i
+                            fbank_beg_i = len(source_ids)
+                            source_ids += fake_token
+                            fbank_mask_i += [1] * len(fake_token)
                 if badcase_flag:
                     continue
+                fbank_beg += [fbank_beg_i + len(input_ids)]
+                fake_token_len += [fake_token_len_i]
                 source_mask = [-100] * len(source_ids)
                 target_out = f"{target_out}<|im_end|>"
                 target_ids = self.tokenizer.encode(target_out)
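The placeholder count above follows the usual conv-subsampling length arithmetic. As a minimal, standalone sketch (assuming the audio front-end is two kernel-3 / stride-2 / padding-1 convolutions followed by a further 2x reduction; the helper name is hypothetical and not part of this commit):

def count_fake_speech_tokens(num_fbank_frames: int) -> int:
    # two conv layers, kernel=3, stride=2, padding=1 (assumed), mirroring
    # olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2 in the diff
    olens = 1 + (num_fbank_frames - 3 + 2 * 1) // 2
    olens = 1 + (olens - 3 + 2 * 1) // 2
    # a final 2x step yields fake_token_len_i, the number of zero placeholders
    return (olens - 1) // 2 + 1

Under these assumptions, 600 fbank frames map to 75 placeholder positions spliced into source_ids.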
@@ -384,9 +397,6 @@
                 labels += source_mask + target_ids
                 fbank.append(speech[0, :, :])
                 fbank_mask += fbank_mask_i
-                if len(fbank_beg_i) < 1:
-                    fbank_beg_i = [-1]
-                fbank_beg += fbank_beg_i
             if len(input_ids) > self.max_token_length:
                 logging.info(
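The source_mask of -100 values relies on the standard PyTorch convention that positions labeled -100 are ignored by cross-entropy, so only the assistant tokens contribute to the loss. A minimal illustration with made-up token ids and an arbitrary vocabulary size (label shifting omitted):

import torch

IGNORE_INDEX = -100  # default ignore_index of PyTorch cross-entropy

source_ids = [11, 12, 0, 0, 13]   # prompt tokens plus zero-valued speech placeholders
target_ids = [21, 22, 23]         # tokenized assistant reply ending in <|im_end|>

input_ids = source_ids + target_ids
labels = [IGNORE_INDEX] * len(source_ids) + target_ids

logits = torch.randn(len(input_ids), 32000)  # (sequence, vocab)
loss = torch.nn.functional.cross_entropy(
    logits, torch.tensor(labels), ignore_index=IGNORE_INDEX
)
# only the three target positions contribute to the loss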
@@ -403,12 +413,14 @@
             fbank_lens = speech_lengths
             fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
             fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
+            fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
             output = {
                 "speech": fbank,
                 "speech_lengths": fbank_lens,
                 "fbank_mask": fbank_mask,
                 "fbank_beg": fbank_beg,
+                "fake_token_len": fake_token_len,
                 "input_ids": input_ids,
                 "attention_mask": attention_mask,
                 "labels_ids": labels,