From 3d9f094e9652d4b84894c6fd4eae39a4a753b0f0 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 16 May 2023 23:48:00 +0800
Subject: [PATCH] train
---
funasr/models/predictor/cif.py | 105 +++++++++++++++++++++++++++++++++++++++++++++++++---
1 file changed, 98 insertions(+), 7 deletions(-)
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index 60cf902..c59e245 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -2,6 +2,7 @@
from torch import nn
import logging
import numpy as np
+from funasr.torch_utils.device_funcs import to_device
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.streaming_utils.utils import sequence_mask
@@ -68,7 +69,8 @@
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
- alphas = torch.cat([alphas, tail_threshold], dim=1)
+ alphas = torch.cat([alphas, zeros_t], dim=1)
+ alphas = torch.add(alphas, tail_threshold)
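+ # Appending a zero column and then adding the masked tail threshold keeps
+ # the time axis at t+1; concatenating the (b, t+1) threshold grew it instead.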
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))
@@ -198,6 +200,94 @@
return acoustic_embeds, token_num, alphas, cif_peak
+ def forward_chunk(self, hidden, cache=None):
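+ # Streaming variant of forward(): runs CIF (Continuous Integrate-and-Fire)
+ # on one chunk and carries the un-fired remainder across calls via `cache`.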
+ batch_size, len_time, hidden_size = hidden.shape
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ output = torch.relu(self.cif_conv1d(queries))
+ output = output.transpose(1, 2)
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
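+ # Per-frame firing weights: scaled sigmoid, with values at or below the
+ # noise threshold clipped to zero.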
+
+ alphas = alphas.squeeze(-1)
+
+ token_length = []
+ list_fires = []
+ list_frames = []
+ cache_alphas = []
+ cache_hiddens = []
+
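+ # Keep weights only for the current chunk: frames in the left context and
+ # right lookahead (as laid out by cache["chunk_size"]) are zeroed.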
+ if cache is not None and "chunk_size" in cache:
+ alphas[:, :cache["chunk_size"][0]] = 0.0
+ alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
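+ # Prepend the partial frame and weight carried over from the previous chunk.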
+ if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
+ cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
+ cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
+ hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
+ alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
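+ # On the last chunk, append a small tail weight so a trailing partial token
+ # can still fire.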
+ if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
+ tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
+ tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
+ tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
+ hidden = torch.cat((hidden, tail_hidden), dim=1)
+ alphas = torch.cat((alphas, tail_alphas), dim=1)
+
+ len_time = alphas.shape[1]
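+ # Integrate-and-fire scan: accumulate weighted frames until the integrated
+ # weight reaches self.threshold, emit a token frame, and keep the overshoot
+ # as the seed of the next token.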
+ for b in range(batch_size):
+ integrate = 0.0
+ frames = torch.zeros(hidden_size, device=hidden.device)
+ list_frame = []
+ list_fire = []
+ for t in range(len_time):
+ alpha = alphas[b][t]
+ if alpha + integrate < self.threshold:
+ integrate += alpha
+ list_fire.append(integrate)
+ frames += alpha * hidden[b][t]
+ else:
+ frames += (self.threshold - integrate) * hidden[b][t]
+ list_frame.append(frames)
+ integrate += alpha
+ list_fire.append(integrate)
+ integrate -= self.threshold
+ frames = integrate * hidden[b][t]
+
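+ # Save the un-fired remainder for the next chunk, normalizing the partial
+ # frame by its accumulated weight when it is non-zero.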
+ cache_alphas.append(integrate)
+ if integrate > 0.0:
+ cache_hiddens.append(frames / integrate)
+ else:
+ cache_hiddens.append(frames)
+
+ token_length.append(torch.tensor(len(list_frame), device=alphas.device))
+ list_fires.append(list_fire)
+ list_frames.append(list_frame)
+
+ cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
+ cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
+ cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
+ cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
+
+ max_token_len = max(token_length)
+ if max_token_len == 0:
+ return hidden, torch.stack(token_length, 0)
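+ # Right-pad each utterance's fired frames to the longest token count in
+ # the batch.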
+ list_ls = []
+ for b in range(batch_size):
+ pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device)
+ if token_length[b] == 0:
+ list_ls.append(pad_frames)
+ else:
+ list_frames[b] = torch.stack(list_frames[b])
+ list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0))
+
+ cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
+ cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
+ cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
+ cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
+ return torch.stack(list_ls, 0), torch.stack(token_length, 0)
+
+
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
tail_threshold = self.tail_threshold
@@ -208,7 +298,8 @@
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
- alphas = torch.cat([alphas, tail_threshold], dim=1)
+ alphas = torch.cat([alphas, zeros_t], dim=1)
+ alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))
@@ -542,9 +633,8 @@
token_num_int = torch.max(token_num).type(torch.int32).item()
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak, token_num2
-
- def get_upsample_timestamp(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
- target_label_length=None, token_num=None):
+
+ def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
h = hidden
b = hidden.shape[0]
context = h.transpose(1, 2)
@@ -596,7 +686,8 @@
mask_2 = torch.cat([ones_t, mask], dim=1)
mask = mask_2 - mask_1
tail_threshold = mask * tail_threshold
- alphas = torch.cat([alphas, tail_threshold], dim=1)
+ alphas = torch.cat([alphas, zeros_t], dim=1)
+ alphas = torch.add(alphas, tail_threshold)
else:
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
tail_threshold = torch.reshape(tail_threshold, (1, 1))
@@ -654,4 +745,4 @@
predictor_alignments = index_div_bool_zeros_count_tile_out
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
- return predictor_alignments.detach(), predictor_alignments_length.detach()
\ No newline at end of file
+ return predictor_alignments.detach(), predictor_alignments_length.detach()
--
Gitblit v1.9.1