zhifu gao
2024-06-24 abb33d6b2097e5b0643326bc1b376a63cdc2f967
funasr/datasets/sense_voice_datasets/datasets.py
@@ -229,3 +229,202 @@
            outputs["target_mask"] = outputs["target_mask"][:, :target_mask_lengths_max]
        return outputs
@tables.register("dataset_classes", "SenseVoiceCTCDataset")
class SenseVoiceCTCDataset(torch.utils.data.Dataset):
    """
    SenseVoiceCTCDataset

    Map-style dataset that turns index entries into acoustic features plus a
    token-id target. The target of each sample is the concatenation of the
    encoded language tag, emotion tag, event tag, ITN tag and transcript, in
    that order.
    """

    # Samples whose encoded transcript is longer than this are skipped.
    _MAX_TARGET_TOKENS = 200

    def __init__(
        self,
        path,
        index_ds: str = None,
        frontend=None,
        tokenizer=None,
        int_pad_value: int = -1,
        float_pad_value: float = 0.0,
        **kwargs,
    ):
        """Build the index dataset and the optional preprocessors.

        Args:
            path: Passed as first argument to the index dataset class.
            index_ds: Name looked up in ``tables.index_ds_classes``.
            frontend: Feature extractor handed to ``extract_fbank``; its ``fs``
                attribute is used as the sample rate (16000 when ``None``).
            tokenizer: Object whose ``encode(text, allowed_special="all")``
                returns a list of token ids.
            int_pad_value: Padding value used by ``collator`` for int32/int64
                tensors.
            float_pad_value: Padding value used by ``collator`` for all other
                tensors.
            **kwargs: Forwarded unchanged to the index dataset class. The keys
                read here are ``preprocessor_speech``, ``preprocessor_speech_conf``,
                ``preprocessor_text``, ``preprocessor_text_conf``, ``sos``,
                ``eos``, ``batch_size``, ``batch_type`` and ``retry``.
        """
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path, **kwargs)

        self.preprocessor_speech = self._build_preprocessor("preprocessor_speech", kwargs)
        self.preprocessor_text = self._build_preprocessor("preprocessor_text", kwargs)

        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer

        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value
        self.sos = kwargs.get("sos", "<|startoftranscript|>")
        self.eos = kwargs.get("eos", "<|endoftext|>")
        self.batch_size = kwargs.get("batch_size")
        self.batch_type = kwargs.get("batch_type")
        self.prompt_ids_len = 0
        self.retry = kwargs.get("retry", 5)

        # Features from a WhisperFrontend get their last two axes swapped in
        # __getitem__. Imported here, as in the original code, so the module
        # does not need it at import time.
        from funasr.frontends.whisper_frontend import WhisperFrontend

        self.permute = isinstance(self.frontend, WhisperFrontend)

    @staticmethod
    def _build_preprocessor(key, kwargs):
        """Instantiate the preprocessor named by ``kwargs[key]``.

        Returns the configured value unchanged when it is falsy (no
        preprocessor requested). A missing ``<key>_conf`` entry is treated as
        an empty configuration instead of failing on ``**None``.
        """
        name = kwargs.get(key, None)
        if not name:
            return name
        preprocessor_class = tables.preprocessor_classes.get(name)
        conf = kwargs.get(f"{key}_conf") or {}
        return preprocessor_class(**conf)

    def get_source_len(self, index):
        """Return the source length the index dataset reports for ``index``."""
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)

    def get_target_len(self, index):
        """Return the target length the index dataset reports for ``index``."""
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        """Return the number of entries in the index dataset."""
        return len(self.index_ds)

    def __getitem__(self, index):
        """Load one sample, retrying with random indices on failure.

        The first attempt uses ``index``; each later attempt (up to
        ``self.retry`` in total) draws a random index. An attempt is abandoned
        when the audio fails to load, the feature length exceeds the frame
        budget, or the encoded transcript is too long.

        Returns:
            A dict with ``speech``, ``speech_lengths``, ``text`` and
            ``text_lengths``, or ``None`` when every attempt was abandoned.
        """
        # batch_size is only a per-batch frame budget when batching is not by
        # example count; collator applies the same distinction. Comparing a
        # frame count against an utterance count would reject almost every
        # sample, and comparing against an unset batch_size would raise.
        check_frame_budget = self.batch_size is not None and self.batch_type != "example"

        output = None
        for attempt in range(self.retry):
            if attempt == 0:
                index_cur = index
            else:
                index_cur = torch.randint(0, len(self.index_ds), ()).item()
            item = self.index_ds[index_cur]

            source = item["source"]
            try:
                data_src = load_audio_text_image_video(source, fs=self.fs)
            except Exception as e:
                logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
                continue

            if self.preprocessor_speech:
                data_src = self.preprocessor_speech(data_src, fs=self.fs)
            speech, speech_lengths = extract_fbank(
                data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
            )  # speech: [b, T, d]

            if check_frame_budget and speech_lengths > self.batch_size:
                continue
            if self.permute:
                speech = speech.permute(0, 2, 1)

            asr_target = item["target"]
            if self.preprocessor_text:
                asr_target = self.preprocessor_text(asr_target)
            emo_target = item["emo_target"]
            event_target = item["event_target"]
            text_language = item.get("text_language", "<|zh|>")
            punc_itn_bottom = item.get("with_or_wo_itn", "<|SPECIAL_TOKEN_13|>")

            target_ids = self.tokenizer.encode(asr_target, allowed_special="all")
            if len(target_ids) > self._MAX_TARGET_TOKENS:
                continue

            lid_ids = self.tokenizer.encode(text_language, allowed_special="all")
            emo_ids = self.tokenizer.encode(emo_target, allowed_special="all")
            event_ids = self.tokenizer.encode(event_target, allowed_special="all")
            punc_itn_bottom_ids = self.tokenizer.encode(punc_itn_bottom, allowed_special="all")

            # [lid, emo, event, itn, text]
            ids = lid_ids + emo_ids + event_ids + punc_itn_bottom_ids + target_ids

            output = {
                "speech": speech[0, :, :],  # drop the leading batch axis
                "speech_lengths": speech_lengths,
                "text": torch.tensor(ids, dtype=torch.int64),
                "text_lengths": torch.tensor([len(ids)], dtype=torch.int32),
            }
            break

        return output

    def collator(self, samples: list = None):
        """Pad a list of samples into one batch dict.

        ``None`` samples are ignored. Tensor fields are padded with
        ``pad_sequence(batch_first=True)``, using ``int_pad_value`` for
        int32/int64 tensors and ``float_pad_value`` otherwise. When
        ``batch_type`` is not ``"example"`` the batch is then halved until it
        fits the frame budget (at most 10 passes).

        When no usable sample is present a small placeholder batch is returned
        so the caller still receives tensors of the expected rank.
        """
        outputs = {}
        for sample in samples or ():
            if sample is None:
                continue
            for key, value in sample.items():
                outputs.setdefault(key, []).append(value)

        if not outputs:
            logging.error("ERROR: data is empty!")
            # NOTE(review): 128 feature dims and token id 58836 are taken as-is
            # from the original placeholder; confirm they match the model.
            # "text" is int64 so it has the same dtype as real batches.
            return {
                "speech": torch.rand((10, 128), dtype=torch.float32)[None, :, :],
                "speech_lengths": torch.tensor([10], dtype=torch.int32)[:, None],
                "text": torch.tensor([58836], dtype=torch.int64)[None, :],
                "text_lengths": torch.tensor([1], dtype=torch.int32)[:, None],
            }

        for key in list(outputs):
            data_list = outputs[key]
            if not isinstance(data_list[0], torch.Tensor):
                continue
            if data_list[0].dtype in (torch.int64, torch.int32):
                pad_value = self.int_pad_value
            else:
                pad_value = self.float_pad_value
            outputs[key] = torch.nn.utils.rnn.pad_sequence(
                data_list, batch_first=True, padding_value=pad_value
            )

        if self.batch_type != "example":
            for i in range(10):
                batch_before = outputs["speech"].shape[0]
                outputs = self._filter_badcase(outputs, i=i)
                # Nothing was dropped, so further passes would be no-ops.
                if outputs["speech"].shape[0] == batch_before:
                    break

        return outputs

    def _filter_badcase(self, outputs, i=0):
        """Drop every other sample when the padded batch exceeds the budget.

        The batch is considered too large when ``batch * frames`` of
        ``outputs["speech"]`` is more than 1.25 times ``self.batch_size``.
        Either the even or the odd positions are kept (chosen at random; always
        the even ones for a single-sample batch), after which ``speech`` and
        ``text`` are trimmed to the longest remaining length. ``outputs`` is
        modified in place and returned. ``i`` is only used in the log message.
        """
        if self.batch_size is None:
            # No budget configured, so there is nothing to enforce.
            return outputs

        b, t, _ = outputs["speech"].shape
        if b * t > self.batch_size * 1.25:
            beg = torch.randint(0, 2, ()).item()
            if b < 2:
                beg = 0
            logging.info(
                f"Warning, b * t: {b * t} > {self.batch_size}, drop half data {i}th, beg:{beg}"
            )
            for key in list(outputs):
                outputs[key] = outputs[key][beg : beg + b : 2]

            speech_lengths_max = outputs["speech_lengths"].max().item()
            outputs["speech"] = outputs["speech"][:, :speech_lengths_max, :]
            text_lengths_max = outputs["text_lengths"].max().item()
            outputs["text"] = outputs["text"][:, :text_lengths_max]
        return outputs