zhifu gao
2024-09-25 2196844d1d6e5b8732c95896bb46f0eacdd9cf9d
funasr/models/ctc/ctc.py
@@ -23,11 +23,17 @@
        ctc_type: str = "builtin",
        reduce: bool = True,
        ignore_nan_grad: bool = True,
        extra_linear: bool = True,
    ):
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        if extra_linear:
            self.ctc_lo = torch.nn.Linear(eprojs, odim)
        else:
            self.ctc_lo = None
        self.ctc_type = ctc_type
        self.ignore_nan_grad = ignore_nan_grad
@@ -130,7 +136,10 @@
            ys_lens: batch of lengths of character sequence (B)
        """
        # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
        ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
        if self.ctc_lo is not None:
            ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
        else:
            ys_hat = hs_pad
        if self.ctc_type == "gtnctc":
            # gtn expects list form for ys
@@ -141,6 +150,7 @@
            # (B, L) -> (BxL,)
            ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)])
        hlens = hlens.to(hs_pad.device)
        loss = self.loss_fn(ys_hat, ys_true, hlens, ys_lens).to(
            device=hs_pad.device, dtype=hs_pad.dtype
        )
@@ -155,7 +165,10 @@
        Returns:
            torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.softmax(self.ctc_lo(hs_pad), dim=2)
        if self.ctc_lo is not None:
            return F.softmax(self.ctc_lo(hs_pad), dim=2)
        else:
            return F.softmax(hs_pad, dim=2)
    def log_softmax(self, hs_pad):
        """log_softmax of frame activations
@@ -165,7 +178,10 @@
        Returns:
            torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
        if self.ctc_lo is not None:
            return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
        else:
            return F.log_softmax(hs_pad, dim=2)
    def argmax(self, hs_pad):
        """argmax of frame activations
@@ -175,4 +191,7 @@
        Returns:
            torch.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
        if self.ctc_lo is not None:
            return torch.argmax(self.ctc_lo(hs_pad), dim=2)
        else:
            return torch.argmax(hs_pad, dim=2)