
fake_token_len_i = 0
fbank_beg_i = -1
fbank_lens_i = []
speech, speech_lengths = [], []
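# Each split is either plain text, which is tokenized directly, or a
# <|startofspeech|>...<|endofspeech|> segment naming an audio source
# (a path/URL, or inline bytes when prefixed with "!!").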
for k, sub_str in enumerate(splits):
    if not sub_str.startswith("<|startofspeech|>"):
        sub_token = tokenizer.encode(sub_str)

    else:
        sub_str = sub_str.replace("<|startofspeech|>", "").replace(
            "<|endofspeech|>", ""
        )
        if sub_str.startswith("!"):
            sub_str = sub_str[1:]
        if sub_str.startswith("!"):  # !!bytes
            sub_str = eval(sub_str[1:])
        try:
            time1 = time.perf_counter()
            data_src = load_audio_text_image_video(sub_str, fs=frontend.fs)
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
        except Exception as e:

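# Accumulate this turn into the sample: source (prompt) and target (answer)
# tokens are concatenated into input_ids, while labels use source_mask in
# place of the prompt tokens so only the target positions carry loss.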
input_source_ids = input_ids + source_ids
input_ids += source_ids + target_ids
labels += source_mask + target_ids
fbank_mask += fbank_mask_i
if len(speech) > 0:
    fbank.append(speech[0, :, :])
    fbank_lens.append(speech_lengths)

input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)

source_ids = torch.tensor(input_source_ids, dtype=torch.int64)
target_ids = torch.tensor(target_ids, dtype=torch.int64)
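# Pad the per-segment fbanks into a single batch tensor; an empty fbank
# list means the sample contains no audio at all.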
if len(fbank) > 0:
    speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
    speech_lengths = torch.nn.utils.rnn.pad_sequence(
        fbank_lens, batch_first=True, padding_value=-1
    )
else:
    speech = []
    speech_lengths = []
output = {
    "speech": speech,
    "speech_lengths": speech_lengths,

}

return output
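# inference_prepare loads and encodes the audio referenced by the prompt and
# assembles the LLM inputs; inference (further below) then runs the LLM.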
def inference_prepare(
    self,
    data_in,
    data_lengths=None,
    key: list = None,
    tokenizer=None,
    frontend=None,
    **kwargs,
):

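    # batch comes from the data-loading step above; speech is an empty list
    # when the prompt contains no audio segments.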
    # audio encoder
    speech = batch["speech"]
    if len(speech) > 0:
        speech_lengths = batch["speech_lengths"][:, 0]
        # fp16
        if kwargs.get("fp16", False):
            speech = speech.to(torch.float16)
        elif kwargs.get("bf16", False):
            speech = speech.to(torch.bfloat16)
        # audio encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # audio_adaptor
        encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)

    input_ids = batch["input_ids"]
    source_ids = batch["source_ids"]
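    # The elided lines embed input_ids and overwrite each speech-placeholder
    # span of inputs_embeds with the matching adaptor output (speech_token),
    # advancing speech_idx once per audio segment.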

        ] = speech_token

        speech_idx += 1
    return inputs_embeds, contents, batch, source_ids, meta_data


def inference(
    self,
    data_in,
    data_lengths=None,
    key: list = None,
    tokenizer=None,
    frontend=None,
    **kwargs,
):

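    # Assemble the LLM inputs (text embeddings with speech embeddings spliced
    # in), then choose the dtype under which to run the LLM.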
    inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
        data_in, data_lengths, key, tokenizer, frontend, **kwargs
    )

    llm_dtype = kwargs.get("llm_dtype", "fp32")
    if llm_dtype == "fp32":