| | |
| | | if len(speech_lengths.size()) > 1: |
| | | speech_lengths = speech_lengths[:, 0] |
| | | |
| | | batch_size = speech.shape[0] |
| | | batch_size, frames, _ = speech.shape |
| | | |
| | | # audio encoder |
| | | encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths) |
| | |
| | | # audio_adaptor |
| | | encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens) |
| | | |
| | | input_ids[input_ids == -1] = 0 |
| | | input_ids[input_ids == -100] = 0 |
| | | if hasattr(self.llm.model, "embed_tokens"): |
| | | inputs_embeds = self.llm.model.embed_tokens(input_ids) |
| | | elif hasattr(self.llm.model.model, "embed_tokens"): |
| | | inputs_embeds = self.llm.model.model.embed_tokens(input_ids) |
| | | else: |
| | | inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids) |
| | | input_ids[input_ids < 0] = 0 |
| | | inputs_embeds = self.llm.model.get_input_embeddings()(input_ids) |
| | | |
| | | batch_size, token_num, dims = inputs_embeds.shape |
| | | _, l, _ = encoder_out.shape |
| | | fbank_mask[fbank_mask < 0] = 0 |
| | | fbank_fake_lens = fbank_mask.sum(-1).to(torch.int32) |
| | | # _, l, _ = encoder_out.shape |
| | | for batch_idx in range(batch_size): |
| | | fbank_beg_idx = fbank_beg[batch_idx, 0].item() |
| | | inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + l, :] = encoder_out[ |
| | | batch_idx, :l, : |
| | | ] |
| | | |
| | | fbank_fake_len = fbank_fake_lens[batch_idx].item() |
| | | fbank_beg_idx = fbank_beg[batch_idx, 0].item() |
| | | min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx) |
| | | fbank_fake_len = encoder_out_lens[batch_idx].item() |
| | | min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx) |
| | | try: |
| | | inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[ |
| | | batch_idx, :min_len, : |
| | | ] |
| | | except Exception as e: |
| | | logging.error(f"{str(e)}, {traceback.format_exc()}") |
| | | logging.info( |
| | | f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, min_len: {min_len}, fbank_fake_len: {fbank_fake_len}" |
| | | ) |
| | | fbank_fake_len = encoder_out_lens[batch_idx].item() |
| | | min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx) |
| | | inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[ |
| | | batch_idx, :min_len, : |
| | | ] |
| | | |
| | | labels_ids[labels_ids == -1] = -100 |
| | | model_outputs = self.llm( |
| | | inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids |
| | | ) |
| | |
| | | stats["acc"] = acc_att |
| | | |
| | | stats["loss"] = torch.clone(loss.detach()) |
| | | stats["batch_size"] = batch_size |
| | | stats["batch_size_x_frames"] = frames * batch_size |
| | | stats["batch_size_real_frames"] = speech_lengths.sum().item() |
| | | stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"] |
| | | stats["batch_size_x_tokens"] = token_num * batch_size |
| | | stats["batch_size_real_tokens"] = attention_mask.sum().item() |
| | | stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"] |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |