manyeyes
2023-06-21 c2a2575f198b1bfd452ea5769bec81bcce3d3a42
funasr/models/e2e_asr_paraformer.py
@@ -242,7 +242,7 @@
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
        if self.use_1st_decoder_loss and pre_loss_att is not None:
            loss = loss + pre_loss_att
            loss = loss + (1 - self.ctc_weight) * pre_loss_att
        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -1160,8 +1160,8 @@
                                                                                           mask_chunk_predictor=mask_chunk_predictor,
                                                                                           target_label_length=None,
                                                                                           )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas[:, :-1],
                                                                                             encoder_out_lens)
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
                                                                                             encoder_out_lens+1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':