| | |
| | | |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from typeguard import check_argument_types |
| | | |
| | | |
| | | class CTC(torch.nn.Module): |
| | |
| | | reduce: bool = True, |
| | | ignore_nan_grad: bool = True, |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | eprojs = encoder_output_size |
| | | self.dropout_rate = dropout_rate |
| | |
| | | if ignore_nan_grad: |
| | | logging.warning("ignore_nan_grad option is not supported for warp_ctc") |
| | | self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce) |
| | | |
| | | elif self.ctc_type == "gtnctc": |
| | | from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction |
| | | |
| | | self.ctc_loss = GTNCTCLossFunction.apply |
| | | else: |
| | | raise ValueError( |
| | | f'ctc_type must be "builtin" or "warpctc": {self.ctc_type}' |