| | |
| | | audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), |
| | | data_type=kwargs.get("data_type", "sound"), |
| | | tokenizer=None) |
| | | if len(kwargs.get("data_type")) > 1: |
| | | if len(kwargs.get("data_type", [])) > 1: |
| | | audio_sample_list, text_token_int_list = audio_sample_list |
| | | text_token_int = text_token_int_list[0].replace(" ", "") |
| | | text_token_int = tokenizer.encode(text_token_int) |
| | |
| | | audio_mask = kwargs.get("audio_mask", None) |
| | | audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None |
| | | text_token_int = kwargs.get("text_token_int", None) |
| | | if audio_token_lengths is None: |
| | | if audio_token_lengths is None and text_token_int is not None: |
| | | audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64) |
| | | |
| | | batch = {"speech": speech, "speech_lengths": speech_lengths} |
| | |
| | | mask=enc_mask, |
| | | target_label_length=audio_token_lengths, |
| | | ) |
| | | loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length) |
| | | loss_pre = 0.0 |
| | | if audio_token_lengths is not None: |
| | | loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length) |
| | | |
| | | return pre_acoustic_embeds, pre_token_length, loss_pre |
| | | |
| | |
| | | audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), |
| | | data_type=kwargs.get("data_type", "sound"), |
| | | tokenizer=None) |
| | | if len(kwargs.get("data_type")) > 1: |
| | | if len(kwargs.get("data_type", [])) > 1: |
| | | audio_sample_list, text_token_int_list = audio_sample_list |
| | | text_token_int = text_token_int_list[0].replace(" ", "") |
| | | text_token_int = text_token_int_list[0] |
| | | text_token_int = tokenizer.encode(text_token_int) |
| | | if text_token_int[0] == tokenizer.bos_token_id: |
| | | text_token_int = text_token_int[1:] |
| | | else: |
| | | text_token_int = None |
| | | time2 = time.perf_counter() |
| | |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | | |
| | | # Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text_token_int=text_token_int) |
| | | res = self.encode(speech, speech_lengths, text_token_int=text_token_int) |
| | | encoder_out = res[0] |
| | | |
| | | # adaptor |
| | | encoder_out = self.adaptor(encoder_out) |
| | | |
| | | prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt) |
| | | prompt_ids = tokenizer.encode(prompt_pre) |
| | | if prompt_ids[0] == tokenizer.bos_token_id: |
| | | prompt_ids = prompt_ids[1:] |
| | | # prompt_ids = prompt_ids + [tokenizer.pad_token_id] |
| | | prompt_length = len(prompt_ids) |
| | | prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"]) |
| | | pad = torch.tensor([tokenizer.pad_token_id], dtype=torch.int64).to(kwargs["device"]) |
| | | |
| | | if hasattr(self.llm.model, "embed_tokens"): |
| | | inputs_embeds = self.llm.model.embed_tokens(prompt_ids) |
| | | pad = self.llm.model.embed_tokens(pad) |
| | | elif hasattr(self.llm.model.model, "embed_tokens"): |
| | | inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids) |
| | | else: |
| | | inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids) |
| | | |
| | | inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio] |
| | | inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1) # [prompt, audio] |
| | | attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"]) |
| | | |
| | | # model_outputs = self.llm.generate( |
| | |
| | | preds = torch.argmax(model_outputs.logits, -1) |
| | | text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True) |
| | | |
| | | text = text[0].split(': ')[-1] |
| | | text = text[0].split(':')[-1] |
| | | text = text.strip() |
| | | if text.startswith("Please\n "): |
| | | text = text.replace("Please\n ", "") |
| | | text = text.strip() |
| | | |
| | | # preds = torch.argmax(model_outputs.logits, -1) |
| | | |