From b454a1054fadbff0ee963944ff42f66b98317582 Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期二, 08 八月 2023 11:17:43 +0800
Subject: [PATCH] update online runtime, including vad-online, paraformer-online, punc-online,2pass (#815)

---
 funasr/export/models/predictor/cif.py |   21 ++++++++++++++-------
 1 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py
index 5ea4a34..dd5dd36 100644
--- a/funasr/export/models/predictor/cif.py
+++ b/funasr/export/models/predictor/cif.py
@@ -36,6 +36,17 @@
 	def forward(self, hidden: torch.Tensor,
 	            mask: torch.Tensor,
 	            ):
+		alphas, token_num = self.forward_cnn(hidden, mask)
+		mask = mask.transpose(-1, -2).float()
+		mask = mask.squeeze(-1)
+		hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
+		acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+		
+		return acoustic_embeds, token_num, alphas, cif_peak
+	
+	def forward_cnn(self, hidden: torch.Tensor,
+	            mask: torch.Tensor,
+	            ):
 		h = hidden
 		context = h.transpose(1, 2)
 		queries = self.pad(context)
@@ -49,12 +60,8 @@
 		alphas = alphas * mask
 		alphas = alphas.squeeze(-1)
 		token_num = alphas.sum(-1)
-		
-		mask = mask.squeeze(-1)
-		hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
-		acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-		
-		return acoustic_embeds, token_num, alphas, cif_peak
+
+		return alphas, token_num
 	
 	def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
 		b, t, d = hidden.size()
@@ -285,4 +292,4 @@
                                 integrate)
 
     fires = torch.stack(list_fires, 1)
-    return fires
\ No newline at end of file
+    return fires

--
Gitblit v1.9.1