From 2196844d1d6e5b8732c95896bb46f0eacdd9cf9d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Wed, 25 Sep 2024 15:10:50 +0800
Subject: [PATCH] Dev kws (#2105)
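
Make the CTC output projection optional. A new extra_linear
constructor flag (default True, preserving the current behavior)
controls whether the ctc_lo linear layer is created. When it is False,
ctc_lo is None and forward/softmax/log_softmax/argmax operate on the
encoder output directly, which assumes the encoder already produces
vocabulary-sized activations. forward() also moves hlens onto the
device of hs_pad before computing the CTC loss.

A minimal usage sketch, not part of the patch: the odim and
encoder_output_size keyword arguments follow the existing constructor
signature, while the vocabulary size and tensor shapes below are
illustrative assumptions.

    import torch
    from funasr.models.ctc.ctc import CTC

    vocab_size = 4233  # hypothetical vocabulary size
    # The encoder output is already vocabulary-sized, so the extra
    # linear projection can be skipped; this requires
    # encoder_output_size == odim.
    ctc = CTC(odim=vocab_size, encoder_output_size=vocab_size,
              extra_linear=False)

    # Dummy encoder output of shape (B, Tmax, odim).
    hs_pad = torch.randn(2, 50, vocab_size)
    # softmax() is applied directly to hs_pad since ctc_lo is None.
    probs = ctc.softmax(hs_pad)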
---
 funasr/models/ctc/ctc.py | 29 ++++++++++++++++++++++++-----
 1 file changed, 24 insertions(+), 5 deletions(-)
diff --git a/funasr/models/ctc/ctc.py b/funasr/models/ctc/ctc.py
index bdfb3a6..8eb64d1 100644
--- a/funasr/models/ctc/ctc.py
+++ b/funasr/models/ctc/ctc.py
@@ -23,11 +23,17 @@
         ctc_type: str = "builtin",
         reduce: bool = True,
         ignore_nan_grad: bool = True,
+        extra_linear: bool = True,
     ):
         super().__init__()
         eprojs = encoder_output_size
         self.dropout_rate = dropout_rate
-        self.ctc_lo = torch.nn.Linear(eprojs, odim)
+
+        if extra_linear:
+            self.ctc_lo = torch.nn.Linear(eprojs, odim)
+        else:
+            self.ctc_lo = None
+
         self.ctc_type = ctc_type
         self.ignore_nan_grad = ignore_nan_grad
 
@@ -130,7 +136,10 @@
             ys_lens: batch of lengths of character sequence (B)
         """
         # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
-        ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
+        if self.ctc_lo is not None:
+            ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
+        else:
+            ys_hat = hs_pad
 
         if self.ctc_type == "gtnctc":
             # gtn expects list form for ys
@@ -141,6 +150,7 @@
             ys_hat = ys_hat.transpose(0, 1)
             # (B, L) -> (BxL,)
             ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)])
+            hlens = hlens.to(hs_pad.device)
             loss = self.loss_fn(ys_hat, ys_true, hlens, ys_lens).to(
                 device=hs_pad.device, dtype=hs_pad.dtype
             )
@@ -155,7 +165,10 @@
         Returns:
             torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
         """
-        return F.softmax(self.ctc_lo(hs_pad), dim=2)
+        if self.ctc_lo is not None:
+            return F.softmax(self.ctc_lo(hs_pad), dim=2)
+        else:
+            return F.softmax(hs_pad, dim=2)
 
     def log_softmax(self, hs_pad):
         """log_softmax of frame activations
@@ -165,7 +178,10 @@
         Returns:
             torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
         """
-        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
+        if self.ctc_lo is not None:
+            return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
+        else:
+            return F.log_softmax(hs_pad, dim=2)
 
     def argmax(self, hs_pad):
         """argmax of frame activations
@@ -175,4 +191,7 @@
         Returns:
             torch.Tensor: argmax applied 2d tensor (B, Tmax)
         """
-        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
+        if self.ctc_lo is not None:
+            return torch.argmax(self.ctc_lo(hs_pad), dim=2)
+        else:
+            return torch.argmax(hs_pad, dim=2)
--
Gitblit v1.9.1