From 2196844d1d6e5b8732c95896bb46f0eacdd9cf9d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Wed, 25 Sep 2024 15:10:50 +0800
Subject: [PATCH] Dev kws (#2105)

---
 funasr/models/ctc/ctc.py |   29 ++++++++++++++++++++++++-----
 1 file changed, 24 insertions(+), 5 deletions(-)

diff --git a/funasr/models/ctc/ctc.py b/funasr/models/ctc/ctc.py
index bdfb3a6..8eb64d1 100644
--- a/funasr/models/ctc/ctc.py
+++ b/funasr/models/ctc/ctc.py
@@ -23,11 +23,17 @@
         ctc_type: str = "builtin",
         reduce: bool = True,
         ignore_nan_grad: bool = True,
+        extra_linear: bool = True,
     ):
         super().__init__()
         eprojs = encoder_output_size
         self.dropout_rate = dropout_rate
-        self.ctc_lo = torch.nn.Linear(eprojs, odim)
+
+        if extra_linear:
+            self.ctc_lo = torch.nn.Linear(eprojs, odim)
+        else:
+            self.ctc_lo = None
+
         self.ctc_type = ctc_type
         self.ignore_nan_grad = ignore_nan_grad
 
@@ -130,7 +136,10 @@
             ys_lens: batch of lengths of character sequence (B)
         """
         # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
-        ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
+        if self.ctc_lo is not None:
+            ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
+        else:
+            ys_hat = hs_pad
 
         if self.ctc_type == "gtnctc":
             # gtn expects list form for ys
@@ -141,6 +150,7 @@
             # (B, L) -> (BxL,)
             ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)])
 
+        hlens = hlens.to(hs_pad.device)
         loss = self.loss_fn(ys_hat, ys_true, hlens, ys_lens).to(
             device=hs_pad.device, dtype=hs_pad.dtype
         )
@@ -155,7 +165,10 @@
         Returns:
             torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
         """
-        return F.softmax(self.ctc_lo(hs_pad), dim=2)
+        if self.ctc_lo is not None:
+            return F.softmax(self.ctc_lo(hs_pad), dim=2)
+        else:
+            return F.softmax(hs_pad, dim=2)
 
     def log_softmax(self, hs_pad):
         """log_softmax of frame activations
@@ -165,7 +178,10 @@
         Returns:
             torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
         """
-        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
+        if self.ctc_lo is not None:
+            return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
+        else:
+            return F.log_softmax(hs_pad, dim=2)
 
     def argmax(self, hs_pad):
         """argmax of frame activations
@@ -175,4 +191,7 @@
         Returns:
             torch.Tensor: argmax applied 2d tensor (B, Tmax)
         """
-        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
+        if self.ctc_lo is not None:
+            return torch.argmax(self.ctc_lo(hs_pad), dim=2)
+        else:
+            return torch.argmax(hs_pad, dim=2)

--
Gitblit v1.9.1