| | |
| | | """ |
| | | if self.criterion_transducer is None: |
| | | try: |
| | | # from warprnnt_pytorch import RNNTLoss |
| | | # self.criterion_transducer = RNNTLoss( |
| | | # reduction="mean", |
| | | # fastemit_lambda=self.fastemit_lambda, |
| | | # ) |
| | | from warp_rnnt import rnnt_loss as RNNTLoss |
| | | self.criterion_transducer = RNNTLoss |
| | | |
| | |
| | | ) |
| | | exit(1) |
| | | |
| | | # loss_transducer = self.criterion_transducer( |
| | | # joint_out, |
| | | # target, |
| | | # t_len, |
| | | # u_len, |
| | | # ) |
| | | log_probs = torch.log_softmax(joint_out, dim=-1) |
| | | |
| | | loss_transducer = self.criterion_transducer( |
| | |
| | | |
| | | batch_size = speech.shape[0] |
| | | text = text[:, : text_lengths.max()] |
| | | #print(speech.shape) |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | |
| | | |
| | | loss_trans = loss_trans_utt + loss_trans_chunk |
| | | loss_ctc = loss_ctc + loss_ctc_chunk |
| | | loss_ctc = loss_att + loss_att_chunk |
| | | loss_att = loss_att + loss_att_chunk |
| | | |
| | | loss = ( |
| | | self.transducer_weight * loss_trans |
| | |
| | | """ |
| | | if self.criterion_transducer is None: |
| | | try: |
| | | # from warprnnt_pytorch import RNNTLoss |
| | | # self.criterion_transducer = RNNTLoss( |
| | | # reduction="mean", |
| | | # fastemit_lambda=self.fastemit_lambda, |
| | | # ) |
| | | from warp_rnnt import rnnt_loss as RNNTLoss |
| | | self.criterion_transducer = RNNTLoss |
| | | |
| | |
| | | ) |
| | | exit(1) |
| | | |
| | | # loss_transducer = self.criterion_transducer( |
| | | # joint_out, |
| | | # target, |
| | | # t_len, |
| | | # u_len, |
| | | # ) |
| | | log_probs = torch.log_softmax(joint_out, dim=-1) |
| | | |
| | | loss_transducer = self.criterion_transducer( |
| | |
| | | ignore_label=self.ignore_id, |
| | | ) |
| | | |
| | | return loss_att, acc_att |
| | | return loss_att, acc_att |