zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/datasets/sense_voice_datasets/datasets.py
@@ -10,21 +10,26 @@
    """
    SenseVoiceDataset
    """
    def __init__(self,
    def __init__(
        self,
                 path,
                 index_ds: str = None,
                 frontend=None,
                 tokenizer=None,
                 int_pad_value: int = -1,
                 float_pad_value: float = 0.0,
                  **kwargs):
        **kwargs,
    ):
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path, **kwargs)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
            preprocessor_speech = preprocessor_speech_class(
                **kwargs.get("preprocessor_speech_conf")
            )
        self.preprocessor_speech = preprocessor_speech
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
@@ -61,7 +66,9 @@
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src, fs=self.fs)
        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
        speech, speech_lengths = extract_fbank(
            data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
        )  # speech: [b, T, d]
        speech = speech.permute(0, 2, 1)
        target = item["target"]
        if self.preprocessor_text:
@@ -85,16 +92,18 @@
        text = torch.tensor(ids, dtype=torch.int64)
        text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
        target_mask = [0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]  # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
        target_mask = (
            [0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]
        )  # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
        target_mask = torch.tensor(target_mask, dtype=torch.float32)
        return {"speech": speech[0, :, :],
        return {
            "speech": speech[0, :, :],
                "speech_lengths": speech_lengths,
                "text": text,
                "text_lengths": text_lengths,
                "target_mask": target_mask,
                }
    
    def collator(self, samples: list=None):
        outputs = {}
@@ -112,7 +121,7 @@
                else:
                    pad_value = self.float_pad_value
                
                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
                outputs[key] = torch.nn.utils.rnn.pad_sequence(
                    data_list, batch_first=True, padding_value=pad_value
                )
        return outputs