From 6427c834dfd97b1f05c6659cdc7ccf010bf82fe1 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 24 四月 2023 19:50:07 +0800
Subject: [PATCH] update
---
funasr/models/predictor/cif.py | 75 ++++++++++++++++++++++++++++++++++---
1 files changed, 68 insertions(+), 7 deletions(-)
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index 60cf902..e80a915 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -68,7 +68,8 @@
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
- alphas = torch.cat([alphas, tail_threshold], dim=1)
+ alphas = torch.cat([alphas, zeros_t], dim=1)
+ alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))
@@ -198,6 +199,65 @@
return acoustic_embeds, token_num, alphas, cif_peak
+ def forward_chunk(self, hidden, cache=None):
+ b, t, d = hidden.size()
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ output = torch.relu(self.cif_conv1d(queries))
+ output = output.transpose(1, 2)
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+
+ alphas = alphas.squeeze(-1)
+ mask_chunk_predictor = None
+ if cache is not None:
+ mask_chunk_predictor = None
+ mask_chunk_predictor = torch.zeros_like(alphas)
+ mask_chunk_predictor[:, cache["pad_left"]:cache["stride"] + cache["pad_left"]] = 1.0
+
+ if mask_chunk_predictor is not None:
+ alphas = alphas * mask_chunk_predictor
+
+ if cache is not None:
+ if cache["is_final"]:
+ alphas[:, cache["stride"] + cache["pad_left"] - 1] += 0.45
+ if cache["cif_hidden"] is not None:
+ hidden = torch.cat((cache["cif_hidden"], hidden), 1)
+ if cache["cif_alphas"] is not None:
+ alphas = torch.cat((cache["cif_alphas"], alphas), -1)
+
+ token_num = alphas.sum(-1)
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+ len_time = alphas.size(-1)
+ last_fire_place = len_time - 1
+ last_fire_remainds = 0.0
+ pre_alphas_length = 0
+
+ mask_chunk_peak_predictor = None
+ if cache is not None:
+ mask_chunk_peak_predictor = None
+ mask_chunk_peak_predictor = torch.zeros_like(cif_peak)
+ if cache["cif_alphas"] is not None:
+ pre_alphas_length = cache["cif_alphas"].size(-1)
+ mask_chunk_peak_predictor[:, :pre_alphas_length] = 1.0
+ mask_chunk_peak_predictor[:, pre_alphas_length + cache["pad_left"]:pre_alphas_length + cache["stride"] + cache["pad_left"]] = 1.0
+
+ if mask_chunk_peak_predictor is not None:
+ cif_peak = cif_peak * mask_chunk_peak_predictor.squeeze(-1)
+
+ for i in range(len_time):
+ if cif_peak[0][len_time - 1 - i] > self.threshold or cif_peak[0][len_time - 1 - i] == self.threshold:
+ last_fire_place = len_time - 1 - i
+ last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold
+ break
+ last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
+ cache["cif_hidden"] = hidden[:, last_fire_place:, :]
+ cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
+ token_num_int = token_num.floor().type(torch.int32).item()
+ return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak
+
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
@@ -208,7 +268,8 @@
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
- alphas = torch.cat([alphas, tail_threshold], dim=1)
+ alphas = torch.cat([alphas, zeros_t], dim=1)
+ alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))
@@ -542,9 +603,8 @@
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak, token_num2
-
- def get_upsample_timestamp(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
- target_label_length=None, token_num=None):
+
+ def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
h = hidden
b = hidden.shape[0]
context = h.transpose(1, 2)
@@ -596,7 +656,8 @@
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
- alphas = torch.cat([alphas, tail_threshold], dim=1)
+ alphas = torch.cat([alphas, zeros_t], dim=1)
+ alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))
@@ -654,4 +715,4 @@
predictor_alignments = index_div_bool_zeros_count_tile_out
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
- return predictor_alignments.detach(), predictor_alignments_length.detach()
\ No newline at end of file
+ return predictor_alignments.detach(), predictor_alignments_length.detach()
--
Gitblit v1.9.1