From b454a1054fadbff0ee963944ff42f66b98317582 Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: Tue, 08 Aug 2023 11:17:43 +0800
Subject: [PATCH] update online runtime, including vad-online, paraformer-online, punc-online,2pass (#815)
---
funasr/export/models/predictor/cif.py | 21 ++++++++++++++-------
1 file changed, 14 insertions(+), 7 deletions(-)
diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py
index 5ea4a34..dd5dd36 100644
--- a/funasr/export/models/predictor/cif.py
+++ b/funasr/export/models/predictor/cif.py
@@ -36,6 +36,17 @@
def forward(self, hidden: torch.Tensor,
mask: torch.Tensor,
):
+ alphas, token_num = self.forward_cnn(hidden, mask)
+ mask = mask.transpose(-1, -2).float()
+ mask = mask.squeeze(-1)
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+
+ return acoustic_embeds, token_num, alphas, cif_peak
+
+ def forward_cnn(self, hidden: torch.Tensor,
+ mask: torch.Tensor,
+ ):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
@@ -49,12 +60,8 @@
alphas = alphas * mask
alphas = alphas.squeeze(-1)
token_num = alphas.sum(-1)
-
- mask = mask.squeeze(-1)
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-
- return acoustic_embeds, token_num, alphas, cif_peak
+
+ return alphas, token_num
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
@@ -285,4 +292,4 @@
integrate)
fires = torch.stack(list_fires, 1)
- return fires
\ No newline at end of file
+ return fires
--
Gitblit v1.9.1