游雁
2024-04-28 b76af7be8cd7428f19ec0ba9a7fd811148fbc358
funasr/datasets/sense_voice_datasets/datasets.py
@@ -51,6 +51,7 @@
        self.batch_size = kwargs.get("batch_size")
        self.batch_type = kwargs.get("batch_type")
        self.prompt_ids_len = 0
        self.retry = kwargs.get("retry", 5)
    def get_source_len(self, index):
        item = self.index_ds[index]
@@ -64,9 +65,21 @@
        return len(self.index_ds)
    def __getitem__(self, index):
        item = self.index_ds[index]
        # import pdb;
        # pdb.set_trace()
        output = None
        for idx in range(self.retry):
            if idx == 0:
                index_cur = index
            else:
                if index <= self.retry:
                    index_cur = index + idx
                else:
                    index_cur = torch.randint(0, index, ()).item()
            item = self.index_ds[index_cur]
        source = item["source"]
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
@@ -76,7 +89,7 @@
        )  # speech: [b, T, d]
        if speech_lengths > self.batch_size:
            return None
                continue
        speech = speech.permute(0, 2, 1)
        target = item["target"]
        if self.preprocessor_text:
@@ -93,7 +106,7 @@
        target_ids = self.tokenizer.encode(target, allowed_special="all")
        target_ids_len = len(target_ids) + 1  # [lid, text]
        if target_ids_len > 200:
            return None
                continue
        eos = self.tokenizer.encode(self.eos, allowed_special="all")  # [eos]
@@ -109,7 +122,8 @@
        target_mask_lengths = len(target_mask)
        target_mask = torch.tensor(target_mask, dtype=torch.float32)
        target_mask_lengths = torch.tensor([target_mask_lengths], dtype=torch.int32)
        return {
            output = {
            "speech": speech[0, :, :],
            "speech_lengths": speech_lengths,
            "text": text,
@@ -117,6 +131,9 @@
            "target_mask": target_mask,
            "target_mask_lengths": target_mask_lengths,
        }
            break
        return output
    def collator(self, samples: list = None):
        outputs = {}
@@ -129,13 +146,30 @@
                outputs[key].append(sample[key])
        if len(outputs) < 1:
            logging.info(f"ERROR: data is empty!")
            logging.error(f"ERROR: data is empty!")
            outputs = {
                "speech": torch.rand((10, 128), dtype=torch.float32),
                "speech_lengths": torch.tensor([10], dtype=torch.int32),
                "text": torch.tensor([58836], dtype=torch.int32),
                "text_lengths": torch.tensor([1], dtype=torch.int32),
                "target_mask": torch.tensor([[0] * (self.prompt_ids_len) + [1] * (1) + [1]]),
                "speech": torch.rand((10, 128), dtype=torch.float32)[None, :, :],
                "speech_lengths": torch.tensor(
                    [
                        10,
                    ],
                    dtype=torch.int32,
                )[:, None],
                "text": torch.tensor(
                    [
                        58836,
                    ],
                    dtype=torch.int32,
                )[None, :],
                "text_lengths": torch.tensor(
                    [
                        1,
                    ],
                    dtype=torch.int32,
                )[:, None],
                "target_mask": torch.tensor([[0] * (self.prompt_ids_len) + [1] * (1) + [1]])[
                    None, :
                ],
            }
            return outputs