语帆
2024-02-21 62178770dccdbf5da42e831898ea32adeeacba45
funasr/models/contextual_paraformer/model.py
@@ -102,17 +102,16 @@
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        pdb.set_trace()
        batch_size = speech.shape[0]
        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        dha_pad = kwargs.get("dha_pad")
        pdb.set_trace()
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        pdb.set_trace()
        loss_ctc, cer_ctc = None, None
        
        stats = dict()
@@ -127,12 +126,11 @@
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc
        
        pdb.set_trace()
        # 2b. Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
            encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
        )
        pdb.set_trace()
        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
@@ -170,26 +168,24 @@
    ):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        pdb.set_trace()
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pdb.set_trace()
        pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                     ignore_id=self.ignore_id)
        pdb.set_trace()
        # -1. bias encoder
        if self.use_decoder_embedding:
            hw_embed = self.decoder.embed(hotword_pad)
        else:
            hw_embed = self.bias_embed(hotword_pad)
        pdb.set_trace()
        hw_embed, (_, _) = self.bias_encoder(hw_embed)
        pdb.set_trace()
        _ind = np.arange(0, hotword_pad.shape[0]).tolist()
        selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
        contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
        pdb.set_trace()
        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
@@ -201,7 +197,7 @@
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds
        pdb.set_trace()
        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
@@ -217,7 +213,7 @@
            loss_ideal = None
        '''
        loss_ideal = None
        pdb.set_trace()
        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
@@ -294,11 +290,11 @@
                                                               enforce_sorted=False)
            _, (h_n, _) = self.bias_encoder(hw_embed)
            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
        pdb.set_trace()
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
        )
        pdb.set_trace()
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
@@ -363,14 +359,11 @@
                                                                 clas_scale=kwargs.get("clas_scale", 1.0))
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        
        pdb.set_trace()
        results = []
        b, n, d = decoder_out.size()
        pdb.set_trace()
        for i in range(b):
            x = encoder_out[i, :encoder_out_lens[i], :]
            am_scores = decoder_out[i, :pre_token_length[i], :]
            pdb.set_trace()
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),