        ctc_type: str = "builtin",
        reduce: bool = True,
        ignore_nan_grad: bool = True,
        extra_linear: bool = True,
    ):
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate

        if extra_linear:
            self.ctc_lo = torch.nn.Linear(eprojs, odim)
        else:
            # No extra projection: the encoder output is used directly as
            # the frame-level vocabulary activations, so it must already
            # have size odim in its last dimension.
            self.ctc_lo = None

        self.ctc_type = ctc_type
        self.ignore_nan_grad = ignore_nan_grad

    def forward(self, hs_pad, hlens, ys_pad, ys_lens):
        """Calculate CTC loss.

        Args:
            hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            hlens: batch of lengths of hidden state sequences (B)
            ys_pad: batch of padded character id sequence tensor (B, Lmax)
            ys_lens: batch of lengths of character sequence (B)
        """
        if self.ctc_lo is not None:
            # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
            ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
        else:
            # Without the extra linear layer, the encoder output is taken
            # as the frame-level vocabulary activations directly.
            ys_hat = hs_pad

        if self.ctc_type == "gtnctc":
            # gtn expects list form for ys
            ys_true = [y[:l] for y, l in zip(ys_pad, ys_lens)]
        else:
            # ys_hat: (B, L, D) -> (L, B, D)
            ys_hat = ys_hat.transpose(0, 1)
            # (B, L) -> (BxL,)
            ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)])

        hlens = hlens.to(hs_pad.device)
        loss = self.loss_fn(ys_hat, ys_true, hlens, ys_lens).to(
            device=hs_pad.device, dtype=hs_pad.dtype
        )
        return loss

    def softmax(self, hs_pad):
        """softmax of frame activations

        Returns:
            torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
        """
        if self.ctc_lo is not None:
            return F.softmax(self.ctc_lo(hs_pad), dim=2)
        else:
            return F.softmax(hs_pad, dim=2)

    def log_softmax(self, hs_pad):
        """log_softmax of frame activations

        Returns:
            torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        if self.ctc_lo is not None:
            return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
        else:
            return F.log_softmax(hs_pad, dim=2)

    def argmax(self, hs_pad):
        """argmax of frame activations

        Returns:
            torch.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        if self.ctc_lo is not None:
            return torch.argmax(self.ctc_lo(hs_pad), dim=2)
        else:
            return torch.argmax(hs_pad, dim=2)
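

# Minimal usage sketch of the extra_linear flag (illustrative only).
# It assumes the class above is named CTC and that its constructor
# signature begins with (odim, encoder_output_size, dropout_rate=0.0, ...),
# as suggested by the attribute assignments in __init__; neither the class
# name nor the leading parameters are shown in this excerpt.
if __name__ == "__main__":
    import torch

    B, T, D, V = 2, 50, 256, 30  # batch, frames, encoder dim, vocab size

    # Default: an extra Linear projects encoder features (D) to vocab (V).
    ctc = CTC(odim=V, encoder_output_size=D)
    print(ctc.argmax(torch.randn(B, T, D)).shape)  # torch.Size([2, 50])

    # extra_linear=False: no projection is applied, so the encoder output
    # must already have size odim (V) in its last dimension.
    ctc_no_proj = CTC(odim=V, encoder_output_size=V, extra_linear=False)
    print(ctc_no_proj.log_softmax(torch.randn(B, T, V)).shape)  # (2, 50, 30)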