funasr/models/seaco_paraformer/model.py
@@ -157,7 +157,7 @@ seaco_label_pad, ) if self.train_decoder: loss_att, acc_att = self._calc_att_loss( loss_att, acc_att, _, _, _ = self._calc_att_loss( encoder_out, encoder_out_lens, text, text_lengths ) loss = loss_seaco + loss_att * self.seaco_weight @@ -350,7 +350,7 @@ pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1] pre_token_length = pre_token_length.round().long() if torch.max(pre_token_length) < 1: return [] return ([],) decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,