游雁
2024-02-23 fa6f60fa762f271d096b8749f3cc9bfc61a6ed48
funasr/datasets/llm_datasets/datasets.py
@@ -24,12 +24,12 @@
         preprocessor_speech = kwargs.get("preprocessor_speech", None)
         if preprocessor_speech:
             preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
-            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
+            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
         self.preprocessor_speech = preprocessor_speech
         preprocessor_text = kwargs.get("preprocessor_text", None)
         if preprocessor_text:
             preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
-            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
         self.preprocessor_text = preprocessor_text
 
         self.frontend = frontend
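
Note on the hunk above: without the `{}` default, a config that omits `preprocessor_speech_conf` (or `preprocessor_text_conf`) makes `kwargs.get(...)` return None, and unpacking None with `**` raises a TypeError when the dataset is constructed. A minimal standalone sketch of the failure mode (DummyPreprocessor is hypothetical, not FunASR code):

    class DummyPreprocessor:
        def __init__(self, **conf):
            self.conf = conf

    kwargs = {}  # "preprocessor_speech_conf" never supplied
    # Old: DummyPreprocessor(**kwargs.get("preprocessor_speech_conf"))
    # unpacks None and raises TypeError (argument after ** must be a mapping).
    # New: falls back to an empty config dict.
    pre = DummyPreprocessor(**kwargs.get("preprocessor_speech_conf", {}))
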
@@ -43,6 +43,7 @@
         self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
             self.prompt)  # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
         self.prompt_af = ""
+        self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
 
     def get_source_len(self, index):
         item = self.index_ds[index]
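
The added IGNORE_INDEX default of -100 matches the default ignore_index of torch.nn.CrossEntropyLoss, so label positions masked with it contribute nothing to the loss. A small self-contained sketch:

    import torch
    import torch.nn as nn

    logits = torch.randn(4, 10)                # [seq_len, vocab]
    labels = torch.tensor([-100, -100, 3, 7])  # prompt positions masked
    # CrossEntropyLoss ignores index -100 by default, so only the
    # last two positions contribute to the loss.
    loss = nn.CrossEntropyLoss()(logits, labels)
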
@@ -64,7 +65,7 @@
         if self.preprocessor_speech:
             data_src = self.preprocessor_speech(data_src, fs=self.fs)
         speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
-        speech = speech.sequeeze(0)
+        speech = speech.squeeze(0)
         target = item["target"]
         if self.preprocessor_text:
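
This hunk fixes a misspelled method name: torch.Tensor has no sequeeze attribute, so the old line raised AttributeError as soon as an item was fetched. squeeze(0) drops the leading batch axis of the [b, T, d] fbank output (b is 1 for a single item):

    import torch

    speech = torch.randn(1, 120, 80)  # [b, T, d] from extract_fbank
    speech = speech.squeeze(0)        # -> [T, d]
    assert speech.shape == (120, 80)
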
@@ -91,10 +92,10 @@
         label_mask = labels_ids.ge(0)  # [False,False,True,True]
         labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100,-100,input,eos]
 
-        audio_mask = [0] * prompt_pre_length + [1] * audio_length
-        torch.tensor(audio_mask, dtype=torch.float32)
+        audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
+        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
 
-        ids = self.tokenizer.encode(target)
+        ids = self.tokenizer.encode(target)  # these token ids differ from labels_ids
         text = torch.tensor(ids, dtype=torch.int64)
         text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
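
Two fixes in the last hunk: the old code built audio_mask but discarded the result of torch.tensor(...) (it was never assigned, so audio_mask stayed a plain Python list), and the new mask appends a trailing 0, presumably to cover one token that follows the audio segment (the surrounding code is not shown here). With hypothetical lengths:

    import torch

    prompt_pre_length, audio_length = 3, 4  # illustrative values only
    audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
    audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
    # tensor([0., 0., 0., 1., 1., 1., 1., 0.])
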