| | |
| | | self.frontend = frontend |
| | | self.specaug = specaug |
| | | self.normalize = normalize |
| | | self.preencoder = preencoder |
| | | self.postencoder = postencoder |
| | | self.encoder = encoder |
| | | |
| | | if not hasattr(self.encoder, "interctc_use_conditioning"): |
| | | self.encoder.interctc_use_conditioning = False |
| | | if self.encoder.interctc_use_conditioning: |
| | | self.encoder.conditioning_layer = torch.nn.Linear( |
| | | vocab_size, self.encoder.output_size() |
| | | ) |
| | | |
| | | self.error_calculator = None |
| | | |
| | |
| | | |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | | intermediate_outs = encoder_out[1] |
| | | encoder_out = encoder_out[0] |
| | | |
| | | loss_att, acc_att, cer_att, wer_att = None, None, None, None |
| | |
| | | # Collect CTC branch stats |
| | | stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None |
| | | stats["cer_ctc"] = cer_ctc |
| | | |
| | | # Intermediate CTC (optional) |
| | | loss_interctc = 0.0 |
| | | if self.interctc_weight != 0.0 and intermediate_outs is not None: |
| | | for layer_idx, intermediate_out in intermediate_outs: |
| | | # we assume intermediate_out has the same length & padding |
| | | # as those of encoder_out |
| | | loss_ic, cer_ic = self._calc_ctc_loss( |
| | | intermediate_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | loss_interctc = loss_interctc + loss_ic |
| | | |
| | | # Collect Intermediate CTC stats |
| | | stats["loss_interctc_layer{}".format(layer_idx)] = ( |
| | | loss_ic.detach() if loss_ic is not None else None |
| | | ) |
| | | stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic |
| | | |
| | | loss_interctc = loss_interctc / len(intermediate_outs) |
| | | |
| | | # calculate whole encoder loss |
| | | loss_ctc = ( |
| | | 1 - self.interctc_weight |
| | | ) * loss_ctc + self.interctc_weight * loss_interctc |
| | | |
| | | # 2b. Attention decoder branch |
| | | if self.ctc_weight != 1.0: |
| | |
| | | if self.normalize is not None: |
| | | feats, feats_lengths = self.normalize(feats, feats_lengths) |
| | | |
| | | # Pre-encoder, e.g. used for raw input data |
| | | if self.preencoder is not None: |
| | | feats, feats_lengths = self.preencoder(feats, feats_lengths) |
| | | |
| | | # 4. Forward encoder |
| | | # feats: (Batch, Length, Dim) |
| | | # -> encoder_out: (Batch, Length2, Dim2) |
| | | if self.encoder.interctc_use_conditioning: |
| | | encoder_out, encoder_out_lens, _ = self.encoder( |
| | | feats, feats_lengths, ctc=self.ctc |
| | | ) |
| | | else: |
| | | encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | | intermediate_outs = encoder_out[1] |
| | | encoder_out = encoder_out[0] |
| | | |
| | | # Post-encoder, e.g. NLU |
| | | if self.postencoder is not None: |
| | | encoder_out, encoder_out_lens = self.postencoder( |
| | | encoder_out, encoder_out_lens |
| | | ) |
| | | encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) |
| | | |
| | | assert encoder_out.size(0) == speech.size(0), ( |
| | | encoder_out.size(), |
| | |
| | | encoder_out.size(), |
| | | encoder_out_lens.max(), |
| | | ) |
| | | |
| | | if intermediate_outs is not None: |
| | | return (encoder_out, intermediate_outs), encoder_out_lens |
| | | |
| | | return encoder_out, encoder_out_lens |
| | | |
| | |
| | | if self.normalize is not None: |
| | | feats, feats_lengths = self.normalize(feats, feats_lengths) |
| | | |
| | | # Pre-encoder, e.g. used for raw input data |
| | | if self.preencoder is not None: |
| | | feats, feats_lengths = self.preencoder(feats, feats_lengths) |
| | | |
| | | # 4. Forward encoder |
| | | # feats: (Batch, Length, Dim) |
| | | # -> encoder_out: (Batch, Length2, Dim2) |
| | | if self.encoder.interctc_use_conditioning: |
| | | encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk( |
| | | feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc |
| | | ) |
| | | else: |
| | | encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"]) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | | intermediate_outs = encoder_out[1] |
| | | encoder_out = encoder_out[0] |
| | | |
| | | # Post-encoder, e.g. NLU |
| | | if self.postencoder is not None: |
| | | encoder_out, encoder_out_lens = self.postencoder( |
| | | encoder_out, encoder_out_lens |
| | | ) |
| | | |
| | | if intermediate_outs is not None: |
| | | return (encoder_out, intermediate_outs), encoder_out_lens |
| | | encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"]) |
| | | |
| | | return encoder_out, torch.tensor([encoder_out.size(1)]) |
| | | |