| | |
| | | if specaug is not None: |
| | | specaug_class = tables.specaug_classes.get(specaug) |
| | | specaug = specaug_class(**specaug_conf) |
| | | if normalize is not None: |
| | | normalize_class = tables.normalize_classes.get(normalize) |
| | | normalize = normalize_class(**normalize_conf) |
| | | |
| | | encoder_class = tables.encoder_classes.get(encoder) |
| | | encoder = encoder_class(input_size=input_size, **encoder_conf) |
| | | encoder_output_size = encoder.output_size() |
| | |
| | | self.ignore_id = ignore_id |
| | | |
| | | self.specaug = specaug |
| | | self.normalize = normalize |
| | | |
| | | self.encoder = encoder |
| | | |
| | | self.decoder = decoder |
| | |
| | | |
| | | self.error_calculator = None |
| | | |
| | | self.share_embedding = share_embedding |
| | | if self.share_embedding: |
| | | self.decoder.embed = None |
| | | |
| | | self.length_normalized_loss = length_normalized_loss |
| | | self.beam_search = None |
| | | self.activation_checkpoint = kwargs.get("activation_checkpoint", False) |
| | | |
| | | def forward( |
| | | self, |
| | |
| | | stats = {} |
| | | |
| | | # 1. Forward decoder |
| | | ys_pad[ys_pad == -1] = 0 |
| | | decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) |
| | | if isinstance(decoder_out, (list, tuple)): |
| | | decoder_out = decoder_out[0] |