| | |
| | | self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 |
| | | |
| | | if self.use_auxiliary_ctc: |
| | | self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size) |
| | | self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size) |
| | | self.ctc_dropout_rate = auxiliary_ctc_dropout_rate |
| | | |
| | | if self.use_auxiliary_lm_loss: |
| | |
| | | |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | | if hasattr(self.encoder, 'overlap_chunk_cls') and self.encoder.overlap_chunk_cls is not None: |
| | | encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens, |
| | | chunk_outs=None) |
| | | # 2. Transducer-related I/O preparation |
| | | decoder_in, target, t_len, u_len = get_transducer_task_io( |
| | | text, |
| | |
| | | self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0 |
| | | |
| | | if self.use_auxiliary_ctc: |
| | | self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size) |
| | | self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size) |
| | | self.ctc_dropout_rate = auxiliary_ctc_dropout_rate |
| | | |
| | | if self.use_auxiliary_att: |