| | |
| | | |
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
    """Run the decoder on predictor embeddings and return log-probabilities.

    Args:
        encoder_out: Encoder output passed straight through to ``self.decoder``.
        encoder_out_lens: Lengths matching ``encoder_out``; forwarded unchanged.
        sematic_embeds: Embeddings fed to the decoder as its target-side input
            (presumably produced by a predictor module -- confirm at caller).
            The misspelled name is kept because callers may pass it by keyword.
        ys_pad_lens: Target lengths; forwarded to the decoder and returned as-is.

    Returns:
        A ``(decoder_out, ys_pad_lens)`` tuple. ``decoder_out`` is the first
        element of the decoder's return value, after ``torch.log_softmax``
        over the last dimension.
    """
    # The decoder may return more than two values, so keep the whole result
    # and index the first element instead of unpacking a fixed-size pair.
    decoder_outs = self.decoder(
        encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
    )
    decoder_out = decoder_outs[0]
    # Normalise over the last axis (assumed to be the vocabulary axis -- TODO confirm).
    decoder_out = torch.log_softmax(decoder_out, dim=-1)
    return decoder_out, ys_pad_lens
| | | |
| | |
| | | postencoder: Optional[AbsPostEncoder], |
| | | decoder: AbsDecoder, |
| | | ctc: CTC, |
| | | joint_network: Optional[torch.nn.Module], |
| | | ctc_weight: float = 0.5, |
| | | interctc_weight: float = 0.0, |
| | | ignore_id: int = -1, |
| | |
| | | postencoder=postencoder, |
| | | decoder=decoder, |
| | | ctc=ctc, |
| | | joint_network=joint_network, |
| | | ctc_weight=ctc_weight, |
| | | interctc_weight=interctc_weight, |
| | | ignore_id=ignore_id, |