# Splice the acoustic encoder outputs into the LLM input embeddings:
# for each sample, the run of fbank placeholder positions (starting at
# fbank_beg) is overwritten in place with that sample's encoder_out frames.
# NOTE(review): this is the interior of a larger method — input_ids,
# encoder_out, fbank_mask and fbank_beg are defined outside this view.

# Map token ids to embedding vectors; shape (batch, token_num, dims)
# per the unpack on the next line.
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)

batch_size, token_num, dims = inputs_embeds.shape
# NOTE(review): dead store — `l` is unconditionally reassigned inside the
# loop before any use here (unless batch_size == 0 and `l` is read later).
_, l, _ = encoder_out.shape
# Clamp negative mask entries (presumably padding markers, e.g. -1) to 0
# so the row-sum below counts only valid fbank slots — TODO confirm.
fbank_mask[fbank_mask < 0] = 0
# Per-sample count of placeholder positions reserved for audio embeddings.
fbank_fake_lens = fbank_mask.sum(-1)
for batch_idx in range(batch_size):

    # Number of audio frames to copy for this sample.
    l = fbank_fake_lens[batch_idx].item()
    # Start of the audio placeholder span. Only the first segment
    # (column 0 of fbank_beg) is handled — assumes a single audio
    # region per sample; TODO confirm against callers.
    fbank_beg_idx = fbank_beg[batch_idx, 0].item()
    # Truncate so the copy never writes past the end of the token sequence.
    min_len = min(l, inputs_embeds.shape[1] - fbank_beg_idx)

    # In-place overwrite of the placeholder embeddings with the encoder
    # output for this sample, limited to min_len frames on both sides.
    inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
        batch_idx, :min_len, :
    ]