zhifu gao
2024-04-08 d19f48e17478be273584853568ac101c994c37e5
funasr/models/llm_asr_nar/model.py
@@ -366,7 +366,7 @@
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.0,
        llm: str = None,
        llm_conf: dict = None,
        adaptor: str = None,
@@ -473,6 +473,15 @@
        
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None
        if ctc_weight > 0.0:
            # build a CTC head over the audio encoder output when the CTC branch is enabled
            if ctc_conf is None:
                ctc_conf = {}
            ctc = CTC(
                odim=vocab_size, encoder_output_size=adaptor_conf["encoder_dim"], **ctc_conf
            )
        self.ctc_weight = ctc_weight
        self.ctc = ctc
    
    def forward(
        self,
@@ -502,9 +511,23 @@
            speech_lengths = speech_lengths[:, 0]
        
        batch_size = speech.shape[0]
        stats = {}
        # audio encoder
        # encode() now returns the raw encoder output (for the CTC branch) as well as
        # the predictor embeddings and length loss that feed the adaptor/LLM branch
        outs = self.encode(speech, speech_lengths, audio_mask=audio_mask)
        enc, enc_lens = outs[0], outs[1]
        encoder_out, encoder_out_lens, loss_pre = outs[2], outs[3], outs[4]
        # decoder: CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                enc, enc_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None
        
        # adaptor
        encoder_out = self.adaptor(encoder_out)
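
The change splits training into two branches that share the audio encoder: the raw encoder output (enc, enc_lens) is scored by the CTC head, while the predictor embeddings go through the adaptor into the LLM. Below is a minimal, self-contained sketch of the CTC branch with made-up shapes, using torch.nn.CTCLoss as a stand-in for FunASR's CTC module:

import torch

# hypothetical shapes, for illustration only
B, T, D, V = 2, 50, 256, 100                    # batch, frames, encoder dim, vocab size
enc = torch.randn(B, T, D)                      # stands in for the audio encoder output
enc_lens = torch.full((B,), T)                  # per-utterance frame counts
text = torch.randint(1, V, (B, 12))             # padded label ids (0 reserved for blank)
text_lens = torch.full((B,), 12)

# a CTC head is a linear projection to the vocabulary plus the CTC criterion
ctc_proj = torch.nn.Linear(D, V)
log_probs = ctc_proj(enc).log_softmax(-1).transpose(0, 1)   # (T, B, V), as CTCLoss expects
loss_ctc = torch.nn.CTCLoss(blank=0, zero_infinity=True)(log_probs, text, enc_lens, text_lens)
print(float(loss_ctc))

In the real model the projection and criterion live inside the CTC class created in __init__, and loss_ctc is interpolated into the LLM loss with ctc_weight (next hunk).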
@@ -536,17 +559,19 @@
        # labels_ids[1:] ->  [prompt, input, target, eos] -> [-1, input, target, eos];
        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
        loss_llm = model_outputs.loss
        stats["loss_llm"] = torch.clone(loss_llm.detach())
        if self.ctc_weight > 0.0:
            loss_llm = self.ctc_weight * loss_ctc + loss_llm
        loss = loss_llm + loss_pre * self.predictor_weight
        stats = {}
        with torch.no_grad():
            preds = torch.argmax(model_outputs.logits, -1)
            acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
            stats["acc"] = acc_att
        
        stats["loss_pre"] = torch.clone(loss_pre.detach())
        stats["loss_llm"] = torch.clone(loss_llm.detach())
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size
        
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
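
The total objective after this change is a weighted sum of the three branch losses: the LLM cross-entropy, the CTC loss, and the predictor length loss. A small arithmetic sketch with illustrative weights (predictor_weight is a model attribute set elsewhere, not shown in this hunk):

import torch

loss_llm = torch.tensor(2.30)        # LLM cross-entropy on the target tokens
loss_ctc = torch.tensor(1.10)        # CTC loss on the raw encoder output
loss_pre = torch.tensor(0.05)        # predictor token-length loss
ctc_weight, predictor_weight = 0.3, 1.0          # example weights, not the defaults

loss = (loss_llm + ctc_weight * loss_ctc) + predictor_weight * loss_pre
stats = {
    "loss_llm": torch.clone(loss_llm.detach()),
    "loss_ctc": torch.clone(loss_ctc.detach()),
    "loss_pre": torch.clone(loss_pre.detach()),
    "loss": torch.clone(loss.detach()),
}
print(stats["loss"])                 # tensor(2.6800)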
@@ -576,7 +601,24 @@
            if audio_token_lengths is not None:
                loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length)
        
        return enc, enc_lens, pre_acoustic_embeds, pre_token_length, loss_pre

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc
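
The CER path above relies on greedy CTC decoding: take the frame-wise argmax over the CTC posteriors, then collapse repeated labels and drop blanks before scoring against the reference. A toy sketch of that collapse rule (the helper and the example frames are made up; FunASR's CTC.argmax and its error calculator do the real work):

import torch

def greedy_ctc_decode(log_probs: torch.Tensor, blank: int = 0) -> list:
    # collapse a frame-wise argmax path into a label sequence
    path = log_probs.argmax(-1).tolist()          # best label per frame
    out, prev = [], blank
    for p in path:
        if p != blank and p != prev:              # drop blanks and repeated frames
            out.append(p)
        prev = p
    return out

# toy frames voting for: blank, 3, 3, blank, 5, 5, 5, blank
frames = torch.nn.functional.one_hot(
    torch.tensor([0, 3, 3, 0, 5, 5, 5, 0]), num_classes=6
).float().log_softmax(-1)
print(greedy_ctc_decode(frames))                  # [3, 5]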
    
    def inference(self,
                  data_in,
@@ -648,7 +690,8 @@
        else:
            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
        
        # inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1)  # [prompt, audio, pad]
        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
        
        # model_outputs = self.llm.generate(
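
On the inference side the change drops the trailing pad embeddings, so the LLM now sees just [prompt, audio]. A hedged sketch of that assembly with dummy tensors (the dimensions and the single-utterance batch are illustrative):

import torch

D = 256                                           # LLM embedding dim, illustrative
prompt_embeds = torch.randn(7, D)                 # embed_tokens(prompt_ids) for one prompt
encoder_out = torch.randn(1, 42, D)               # adaptor output, already batched

# [prompt, audio]: prepend the prompt embeddings; no pad embeddings are appended any more
inputs_embeds = torch.cat((prompt_embeds[None, :, :], encoder_out), dim=1)
attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long)
print(inputs_embeds.shape, attention_mask.shape)  # torch.Size([1, 49, 256]) torch.Size([1, 49])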