| | |
| | | dropout_rate: float = 0.0, |
| | | embed_dropout_rate: float = 0.0, |
| | | embed_pad: int = 0, |
| | | use_embed_mask: bool = False, |
| | | ) -> None: |
| | | """Construct a RNNDecoder object.""" |
| | | super().__init__() |
| | |
| | | |
| | | self.device = next(self.parameters()).device |
| | | self.score_cache = {} |
| | | |
| | | self.use_embed_mask = use_embed_mask |
| | | if self.use_embed_mask: |
| | | self._embed_mask = SpecAug( |
| | | time_mask_width_range=3, |
| | | num_time_mask=4, |
| | | apply_freq_mask=False, |
| | | apply_time_warp=False |
| | | ) |
| | | |
| | | def forward( |
| | | self, |
| | |
| | | states = self.init_state(labels.size(0)) |
| | | |
| | | dec_embed = self.dropout_embed(self.embed(labels)) |
| | | if self.use_embed_mask and self.training: |
| | | dec_embed = self._embed_mask(dec_embed, label_lens)[0] |
| | | dec_out, states = self.rnn_forward(dec_embed, states) |
| | | return dec_out |
| | | |