From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 13 Sep 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/export/models/predictor/cif.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)
diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py
index 5ea4a34..03c4433 100644
--- a/funasr/export/models/predictor/cif.py
+++ b/funasr/export/models/predictor/cif.py
@@ -36,6 +36,17 @@
def forward(self, hidden: torch.Tensor,
mask: torch.Tensor,
):
+ alphas, token_num = self.forward_cnn(hidden, mask)
+ mask = mask.transpose(-1, -2).float()
+ mask = mask.squeeze(-1)
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+
+ return acoustic_embeds, token_num, alphas, cif_peak
+
+ def forward_cnn(self, hidden: torch.Tensor,
+ mask: torch.Tensor,
+ ):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
@@ -49,12 +60,8 @@
alphas = alphas * mask
alphas = alphas.squeeze(-1)
token_num = alphas.sum(-1)
-
- mask = mask.squeeze(-1)
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-
- return acoustic_embeds, token_num, alphas, cif_peak
+
+ return alphas, token_num
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
@@ -281,8 +288,8 @@
fire_place = integrate >= threshold
integrate = torch.where(fire_place,
- integrate - torch.ones([batch_size], device=alphas.device),
+ integrate - torch.ones([batch_size], device=alphas.device)*threshold,
integrate)
fires = torch.stack(list_fires, 1)
- return fires
\ No newline at end of file
+ return fires
--
Gitblit v1.9.1