| | |
| | | |
| | | # 1. Forward decoder |
| | | decoder_outs = self.decoder( |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info, ret_attn=(ideal_attn is not None) |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info |
| | | ) |
| | | decoder_out, _, attn = decoder_outs[0], decoder_outs[1], decoder_outs[2] |
| | | |
| | | if self.crit_attn_weight > 0 and attn.shape[-1] > 1: |
| | | ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0) |
| | | attn_non_blank = attn[:,:,:,:-1] |