From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/export/models/predictor/cif.py |   23 +++++++++++++++--------
 1 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py
index 5ea4a34..03c4433 100644
--- a/funasr/export/models/predictor/cif.py
+++ b/funasr/export/models/predictor/cif.py
@@ -36,6 +36,17 @@
 	def forward(self, hidden: torch.Tensor,
 	            mask: torch.Tensor,
 	            ):
+		alphas, token_num = self.forward_cnn(hidden, mask)
+		mask = mask.transpose(-1, -2).float()
+		mask = mask.squeeze(-1)
+		hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
+		acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+		
+		return acoustic_embeds, token_num, alphas, cif_peak
+	
+	def forward_cnn(self, hidden: torch.Tensor,
+	            mask: torch.Tensor,
+	            ):
 		h = hidden
 		context = h.transpose(1, 2)
 		queries = self.pad(context)
@@ -49,12 +60,8 @@
 		alphas = alphas * mask
 		alphas = alphas.squeeze(-1)
 		token_num = alphas.sum(-1)
-		
-		mask = mask.squeeze(-1)
-		hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
-		acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-		
-		return acoustic_embeds, token_num, alphas, cif_peak
+
+		return alphas, token_num
 	
 	def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
 		b, t, d = hidden.size()
@@ -281,8 +288,8 @@
 
         fire_place = integrate >= threshold
         integrate = torch.where(fire_place,
-                                integrate - torch.ones([batch_size], device=alphas.device),
+                                integrate - torch.ones([batch_size], device=alphas.device)*threshold,
                                 integrate)
 
     fires = torch.stack(list_fires, 1)
-    return fires
\ No newline at end of file
+    return fires

--
Gitblit v1.9.1