| | |
| | | decoder_conf: dict = None, |
| | | ctc: str = None, |
| | | ctc_conf: dict = None, |
| | | ctc_weight: float = 0.5, |
| | | ctc_weight: float = 0.0, |
| | | llm: str = None, |
| | | llm_conf: dict = None, |
| | | adaptor: str = None, |
| | |
| | | |
| | | self.length_normalized_loss = length_normalized_loss |
| | | self.beam_search = None |
| | | if ctc_weight > 0.0: |
| | | if ctc_conf is None: |
| | | ctc_conf = {} |
| | | |
| | | ctc = CTC( |
| | | odim=vocab_size, encoder_output_size=adaptor_conf["encoder_dim"], **ctc_conf |
| | | ) |
| | | self.ctc_weight = ctc_weight |
| | | self.ctc = ctc |
| | | |
| | | def forward( |
| | | self, |
| | |
| | | speech_lengths = speech_lengths[:, 0] |
| | | |
| | | batch_size = speech.shape[0] |
| | | |
| | | |
| | | stats = {} |
| | | # audio encoder |
| | | encoder_out, encoder_out_lens, loss_pre = self.encode(speech, speech_lengths, audio_mask=audio_mask) |
| | | outs = self.encode(speech, speech_lengths, audio_mask=audio_mask) |
| | | enc, enc_lens = outs[0], outs[1] |
| | | encoder_out, encoder_out_lens, loss_pre = outs[2], outs[3], outs[4] |
| | | |
| | | |
| | | # decoder: CTC branch |
| | | |
| | | if self.ctc_weight != 0.0: |
| | | loss_ctc, cer_ctc = self._calc_ctc_loss( |
| | | enc, enc_lens, text, text_lengths |
| | | ) |
| | | |
| | | # Collect CTC branch stats |
| | | stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None |
| | | |
| | | # adaptor |
| | | encoder_out = self.adaptor(encoder_out) |
| | |
| | | # labels_ids[1:] -> [prompt, input, target, eos] -> [-1, input, target, eos]; |
| | | model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids) |
| | | loss_llm = model_outputs.loss |
| | | stats["loss_llm"] = torch.clone(loss_llm.detach()) |
| | | if self.ctc_weight > 0.0: |
| | | loss_llm = self.ctc_weight * loss_ctc + loss_llm |
| | | loss = loss_llm + loss_pre * self.predictor_weight |
| | | stats = {} |
| | | |
| | | with torch.no_grad(): |
| | | preds = torch.argmax(model_outputs.logits, -1) |
| | | acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100) |
| | | stats["acc"] = acc_att |
| | | |
| | | |
| | | stats["loss_pre"] = torch.clone(loss_pre.detach()) |
| | | stats["loss_llm"] = torch.clone(loss_llm.detach()) |
| | | stats["loss"] = torch.clone(loss.detach()) |
| | | stats["batch_size"] = batch_size |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | |
| | | if audio_token_lengths is not None: |
| | | loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length) |
| | | |
| | | return pre_acoustic_embeds, pre_token_length, loss_pre |
| | | return enc, enc_lens, pre_acoustic_embeds, pre_token_length, loss_pre |
| | | |
def _calc_ctc_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    """Compute the CTC loss and, outside training, the CTC error rate.

    Args:
        encoder_out: Encoder output fed to ``self.ctc`` (presumably
            (batch, frames, dim) — TODO confirm against the CTC module).
        encoder_out_lens: Valid lengths of ``encoder_out``.
        ys_pad: Padded target token ids.
        ys_pad_lens: Valid lengths of ``ys_pad``.

    Returns:
        A ``(loss_ctc, cer_ctc)`` tuple. ``loss_ctc`` is whatever ``self.ctc``
        returns. ``cer_ctc`` is the result of ``self.error_calculator`` on
        the greedy (argmax) CTC hypothesis. It is ``None`` while
        ``self.training`` is true or when no error calculator is configured.
    """
    # CTC loss for the batch.
    loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

    # The error rate is an evaluation-only metric, so skip it while training.
    cer_ctc = None
    if not self.training and self.error_calculator is not None:
        # Greedy CTC hypothesis. `.detach()` replaces the legacy `.data`
        # accessor: it also drops autograd history, but without hiding
        # in-place modifications from autograd.
        ys_hat = self.ctc.argmax(encoder_out).detach()
        cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
    return loss_ctc, cer_ctc
| | | |
| | | def inference(self, |
| | | data_in, |
| | |
| | | else: |
| | | inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids) |
| | | |
| | | inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1) # [prompt, audio] |
| | | # inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1) # [prompt, audio, pad] |
| | | inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio] |
| | | attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"]) |
| | | |
| | | # model_outputs = self.llm.generate( |