zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/datasets/llm_datasets_vicuna/datasets.py
@@ -11,21 +11,25 @@
    AudioLLMDataset
    """
    
    def __init__(self,
    def __init__(
        self,
                 path,
                 index_ds: str = None,
                 frontend=None,
                 tokenizer=None,
                 int_pad_value: int = -1,
                 float_pad_value: float = 0.0,
                 **kwargs):
        **kwargs
    ):
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path, **kwargs)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
            preprocessor_speech = preprocessor_speech_class(
                **kwargs.get("preprocessor_speech_conf", {})
            )
        self.preprocessor_speech = preprocessor_speech
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
@@ -49,7 +53,6 @@
        self.prompt_template = "USER: {}\n ASSISTANT:"
        self.answer_template = "{}"
        
    def get_source_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)
@@ -69,11 +72,16 @@
        data_src = load_audio_text_image_video(source, fs=self.fs)
        if self.preprocessor_speech:
            data_src = self.preprocessor_speech(data_src, fs=self.fs)
        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend,
                                               is_final=True)  # speech: [b, T, d]
        speech, speech_lengths = extract_fbank(
            data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
        )  # speech: [b, T, d]
        speech = speech.squeeze(0)
        audio_pseudo_length = (speech.shape[0] + 1) // self.audio_adaptor_downsample_rate // self.audio_encoder_downsample_rate
        audio_pseudo_length = (
            (speech.shape[0] + 1)
            // self.audio_adaptor_downsample_rate
            // self.audio_encoder_downsample_rate
        )
        audio_pseudo = torch.full((audio_pseudo_length,), -1) # placeholder
        
        target = item["target"]
@@ -119,7 +127,8 @@
        text = torch.tensor(ids, dtype=torch.int64)
        text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
        
        return {"speech": speech,
        return {
            "speech": speech,
                "speech_lengths": speech_lengths,
                "text": text,
                "text_lengths": text_lengths,
@@ -146,5 +155,7 @@
                else:
                    pad_value = self.float_pad_value
                
                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
                outputs[key] = torch.nn.utils.rnn.pad_sequence(
                    data_list, batch_first=True, padding_value=pad_value
                )
        return outputs