From b8bf792ce7df411ae4ed8d2bd8c8eba7c59e082b Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: 星期三, 10 四月 2024 11:37:27 +0800
Subject: [PATCH] fix bug
---
funasr/models/paraformer/cif_predictor.py | 308 ++++++++++++++++++++++++++++++--------------------
1 files changed, 185 insertions(+), 123 deletions(-)
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 60ddc24..d538e21 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -10,7 +10,7 @@
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
-
+from torch.cuda.amp import autocast
@tables.register("predictor_classes", "CifPredictor")
class CifPredictor(torch.nn.Module):
@@ -28,42 +28,44 @@
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
target_label_length=None):
- h = hidden
- context = h.transpose(1, 2)
- queries = self.pad(context)
- memory = self.cif_conv1d(queries)
- output = memory + context
- output = self.dropout(output)
- output = output.transpose(1, 2)
- output = torch.relu(output)
- output = self.cif_output(output)
- alphas = torch.sigmoid(output)
- alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
- if mask is not None:
- mask = mask.transpose(-1, -2).float()
- alphas = alphas * mask
- if mask_chunk_predictor is not None:
- alphas = alphas * mask_chunk_predictor
- alphas = alphas.squeeze(-1)
- mask = mask.squeeze(-1)
- if target_label_length is not None:
- target_length = target_label_length
- elif target_label is not None:
- target_length = (target_label != ignore_id).float().sum(-1)
- else:
- target_length = None
- token_num = alphas.sum(-1)
- if target_length is not None:
- alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
- elif self.tail_threshold > 0.0:
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+
+ with autocast(False):
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ memory = self.cif_conv1d(queries)
+ output = memory + context
+ output = self.dropout(output)
+ output = output.transpose(1, 2)
+ output = torch.relu(output)
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+ if mask is not None:
+ mask = mask.transpose(-1, -2).float()
+ alphas = alphas * mask
+ if mask_chunk_predictor is not None:
+ alphas = alphas * mask_chunk_predictor
+ alphas = alphas.squeeze(-1)
+ mask = mask.squeeze(-1)
+ if target_label_length is not None:
+ target_length = target_label_length
+ elif target_label is not None:
+ target_length = (target_label != ignore_id).float().sum(-1)
+ else:
+ target_length = None
+ token_num = alphas.sum(-1)
+ if target_length is not None:
+ alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+ elif self.tail_threshold > 0.0:
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-
- if target_length is None and self.tail_threshold > 0.0:
- token_num_int = torch.max(token_num).type(torch.int32).item()
- acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
-
+ if target_length is None and self.tail_threshold > 0.0:
+ token_num_int = torch.max(token_num).type(torch.int32).item()
+ acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+
return acoustic_embeds, token_num, alphas, cif_peak
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
@@ -153,7 +155,7 @@
tf2torch_tensor_name_prefix_tf="seq2seq/cif",
tail_mask=True,
):
- super(CifPredictorV2, self).__init__()
+ super().__init__()
self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
@@ -169,41 +171,43 @@
def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
target_label_length=None):
- h = hidden
- context = h.transpose(1, 2)
- queries = self.pad(context)
- output = torch.relu(self.cif_conv1d(queries))
- output = output.transpose(1, 2)
-
- output = self.cif_output(output)
- alphas = torch.sigmoid(output)
- alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
- if mask is not None:
- mask = mask.transpose(-1, -2).float()
- alphas = alphas * mask
- if mask_chunk_predictor is not None:
- alphas = alphas * mask_chunk_predictor
- alphas = alphas.squeeze(-1)
- mask = mask.squeeze(-1)
- if target_label_length is not None:
- target_length = target_label_length.squeeze(-1)
- elif target_label is not None:
- target_length = (target_label != ignore_id).float().sum(-1)
- else:
- target_length = None
- token_num = alphas.sum(-1)
- if target_length is not None:
- alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
- elif self.tail_threshold > 0.0:
- if self.tail_mask:
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+
+ with autocast(False):
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ output = torch.relu(self.cif_conv1d(queries))
+ output = output.transpose(1, 2)
+
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+ if mask is not None:
+ mask = mask.transpose(-1, -2).float()
+ alphas = alphas * mask
+ if mask_chunk_predictor is not None:
+ alphas = alphas * mask_chunk_predictor
+ alphas = alphas.squeeze(-1)
+ mask = mask.squeeze(-1)
+ if target_label_length is not None:
+ target_length = target_label_length.squeeze(-1)
+ elif target_label is not None:
+ target_length = (target_label != ignore_id).float().sum(-1)
else:
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
-
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
- if target_length is None and self.tail_threshold > 0.0:
- token_num_int = torch.max(token_num).type(torch.int32).item()
- acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+ target_length = None
+ token_num = alphas.sum(-1)
+ if target_length is not None:
+ alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+ elif self.tail_threshold > 0.0:
+ if self.tail_mask:
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+ else:
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=None)
+
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+ if target_length is None and self.tail_threshold > 0.0:
+ token_num_int = torch.max(token_num).type(torch.int32).item()
+ acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak
@@ -371,61 +375,119 @@
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
return predictor_alignments.detach(), predictor_alignments_length.detach()
- def gen_tf2torch_map_dict(self):
+@tables.register("predictor_classes", "CifPredictorV2Export")
+class CifPredictorV2Export(torch.nn.Module):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+
+ self.pad = model.pad
+ self.cif_conv1d = model.cif_conv1d
+ self.cif_output = model.cif_output
+ self.threshold = model.threshold
+ self.smooth_factor = model.smooth_factor
+ self.noise_threshold = model.noise_threshold
+ self.tail_threshold = model.tail_threshold
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
- ## predictor
- "{}.cif_conv1d.weight".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": (2, 1, 0),
- }, # (256,256,3),(3,256,256)
- "{}.cif_conv1d.bias".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.cif_output.weight".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1,256),(1,256,1)
- "{}.cif_output.bias".format(tensor_name_prefix_torch):
- {"name": "{}/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1,),(1,)
- }
- return map_dict_local
+ def forward(self, hidden: torch.Tensor,
+ mask: torch.Tensor,
+ ):
+ alphas, token_num = self.forward_cnn(hidden, mask)
+ mask = mask.transpose(-1, -2).float()
+ mask = mask.squeeze(-1)
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
+ acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
+
+ return acoustic_embeds, token_num, alphas, cif_peak
+
+ def forward_cnn(self, hidden: torch.Tensor,
+ mask: torch.Tensor,
+ ):
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ output = torch.relu(self.cif_conv1d(queries))
+ output = output.transpose(1, 2)
+
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+ mask = mask.transpose(-1, -2).float()
+ alphas = alphas * mask
+ alphas = alphas.squeeze(-1)
+ token_num = alphas.sum(-1)
+
+ return alphas, token_num
+
+ def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
+ b, t, d = hidden.size()
+ tail_threshold = self.tail_threshold
+
+ zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
+ ones_t = torch.ones_like(zeros_t)
+
+ mask_1 = torch.cat([mask, zeros_t], dim=1)
+ mask_2 = torch.cat([ones_t, mask], dim=1)
+ mask = mask_2 - mask_1
+ tail_threshold = mask * tail_threshold
+ alphas = torch.cat([alphas, zeros_t], dim=1)
+ alphas = torch.add(alphas, tail_threshold)
+
+ zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
+ hidden = torch.cat([hidden, zeros], dim=1)
+ token_num = alphas.sum(dim=-1)
+ token_num_floor = torch.floor(token_num)
+
+ return hidden, alphas, token_num_floor
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
- map_dict = self.gen_tf2torch_map_dict()
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- if map_dict[name]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
- if map_dict[name]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
+@torch.jit.script
+def cif_export(hidden, alphas, threshold: float):
+ batch_size, len_time, hidden_size = hidden.size()
+ threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
- return var_dict_torch_update
+ # loop varss
+ integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
+ frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
+ # intermediate vars along time
+ list_fires = []
+ list_frames = []
+
+ for t in range(len_time):
+ alpha = alphas[:, t]
+ distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
+
+ integrate += alpha
+ list_fires.append(integrate)
+
+ fire_place = integrate >= threshold
+ integrate = torch.where(fire_place,
+ integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
+ integrate)
+ cur = torch.where(fire_place,
+ distribution_completion,
+ alpha)
+ remainds = alpha - cur
+
+ frame += cur[:, None] * hidden[:, t, :]
+ list_frames.append(frame)
+ frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
+ remainds[:, None] * hidden[:, t, :],
+ frame)
+
+ fires = torch.stack(list_fires, 1)
+ frames = torch.stack(list_frames, 1)
+
+ fire_idxs = fires >= threshold
+ frame_fires = torch.zeros_like(hidden)
+ max_label_len = frames[0, fire_idxs[0]].size(0)
+ for b in range(batch_size):
+ frame_fire = frames[b, fire_idxs[b]]
+ frame_len = frame_fire.size(0)
+ frame_fires[b, :frame_len, :] = frame_fire
+
+ if frame_len >= max_label_len:
+ max_label_len = frame_len
+ frame_fires = frame_fires[:, :max_label_len, :]
+ return frame_fires, fires
class mae_loss(torch.nn.Module):
--
Gitblit v1.9.1