jmwang66
2023-06-29 98abc0e5ac1a1da0fe1802d9ffb623802fbf0b2f
funasr/models/ctc.py
@@ -2,7 +2,6 @@
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
class CTC(torch.nn.Module):
@@ -25,7 +24,6 @@
        reduce: bool = True,
        ignore_nan_grad: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
@@ -41,11 +39,6 @@
            if ignore_nan_grad:
                logging.warning("ignore_nan_grad option is not supported for warp_ctc")
            self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
        elif self.ctc_type == "gtnctc":
            from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction
            self.ctc_loss = GTNCTCLossFunction.apply
        else:
            raise ValueError(
                f'ctc_type must be "builtin" or "warpctc": {self.ctc_type}'