| | |
| | | |
| | | def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens): |
| | | |
| | | decoder_out, _ = self.decoder( |
| | | decoder_outs = self.decoder( |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens |
| | | ) |
| | | decoder_out = decoder_outs[0] |
| | | decoder_out = torch.log_softmax(decoder_out, dim=-1) |
| | | return decoder_out, ys_pad_lens |
| | | |
| | |
| | | def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds): |
| | | |
| | | tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device) |
| | | ys_pad *= tgt_mask[:, :, 0] |
| | | ys_pad = ys_pad * tgt_mask[:, :, 0] |
| | | ys_pad_embed = self.decoder.embed(ys_pad) |
| | | with torch.no_grad(): |
| | | decoder_outs = self.decoder( |
| | |
| | | postencoder: Optional[AbsPostEncoder], |
| | | decoder: AbsDecoder, |
| | | ctc: CTC, |
| | | joint_network: Optional[torch.nn.Module], |
| | | ctc_weight: float = 0.5, |
| | | interctc_weight: float = 0.0, |
| | | ignore_id: int = -1, |
| | |
| | | postencoder=postencoder, |
| | | decoder=decoder, |
| | | ctc=ctc, |
| | | joint_network=joint_network, |
| | | ctc_weight=ctc_weight, |
| | | interctc_weight=interctc_weight, |
| | | ignore_id=ignore_id, |