
# Decode with the contextual (hotword) embeddings as additional biasing input.
decoder_outs = self.decoder(
    encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
)
decoder_out, _, attn = decoder_outs[0], decoder_outs[1], decoder_outs[2]
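# NOTE (assumption): `attn` is read here as the decoder's cross-attention over
# the contextual entries, whose last column acts as a blank/"no-bias" slot;
# it is only consumed by the disabled attention-loss block below.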

'''
# Disabled: auxiliary loss pushing the contextual attention toward an ideal alignment.
if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
    # Smooth the ideal attention targets.
    ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
    # Drop the trailing blank column from both the predicted and ideal attention.
    attn_non_blank = attn[:, :, :, :-1]
    ideal_attn_non_blank = ideal_attn[:, :, :-1]  # reconstructed: counterpart slice of the ideal targets
    # Max-pool over attention heads, then score against the ideal alignment.
    loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
else:
    loss_ideal = None
'''
loss_ideal = None

if decoder_out_1st is None:
    decoder_out_1st = decoder_out
# 2. Compute attention loss