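# Derive per-sample audio token lengths from the padding mask; when no mask is supplied,
# fall back to the length of the text tokens (this path assumes text_token_int is provided then).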
audio_mask = kwargs.get("audio_mask", None)
audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
text_token_int = kwargs.get("text_token_int", None)
if audio_token_lengths is None:
    audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)

batch = {"speech": speech, "speech_lengths": speech_lengths}
enc, enc_lens = self.audio_encoder.encode(**batch)
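# Load the raw audio (and, when data_type is a list that also names a text stream,
# e.g. ["sound", "text"], the paired transcript), timing the step for meta_data.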
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                data_type=kwargs.get("data_type", "sound"),
                                                tokenizer=None)  # raw text is returned here and tokenized manually below
if len(kwargs.get("data_type")) > 1:
    audio_sample_list, text_token_int_list = audio_sample_list
    text_token_int = text_token_int_list[0].replace(" ", "")
    text_token_int = tokenizer.encode(text_token_int)
else:
    text_token_int = None
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
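# Extract fbank features with the model frontend and move them to the inference device.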
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                       frontend=frontend)
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
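# The text tokens prepared above are forwarded to self.encode together with the speech features.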
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text_token_int=text_token_int)

# adaptor
encoder_out = self.adaptor(encoder_out)