From 30c40c643c19f6e2ac8679fa76d09d0f9ceccc65 Mon Sep 17 00:00:00 2001
From: chenmengzheAAA <123789350+chenmengzheAAA@users.noreply.github.com>
Date: 星期四, 14 九月 2023 18:00:43 +0800
Subject: [PATCH] Update modelscope_models.md
---
funasr/models/predictor/cif.py | 219 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 218 insertions(+), 1 deletions(-)
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index 5615373..5f18c4d 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -1,9 +1,12 @@
import torch
from torch import nn
+from torch import Tensor
import logging
import numpy as np
+from funasr.torch_utils.device_funcs import to_device
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.streaming_utils.utils import sequence_mask
+from typing import Optional, Tuple
class CifPredictor(nn.Module):
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
@@ -198,6 +201,95 @@
acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
return acoustic_embeds, token_num, alphas, cif_peak
+
+ def forward_chunk(self, hidden, cache=None):
+ batch_size, len_time, hidden_size = hidden.shape
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ output = torch.relu(self.cif_conv1d(queries))
+ output = output.transpose(1, 2)
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+
+ alphas = alphas.squeeze(-1)
+
+ token_length = []
+ list_fires = []
+ list_frames = []
+ cache_alphas = []
+ cache_hiddens = []
+
+ if cache is not None and "chunk_size" in cache:
+ alphas[:, :cache["chunk_size"][0]] = 0.0
+ if "is_final" in cache and not cache["is_final"]:
+ alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
+ if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
+ cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
+ cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
+ hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
+ alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
+ if cache is not None and "is_final" in cache and cache["is_final"]:
+ tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
+ tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
+ tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
+ hidden = torch.cat((hidden, tail_hidden), dim=1)
+ alphas = torch.cat((alphas, tail_alphas), dim=1)
+
+ len_time = alphas.shape[1]
+ for b in range(batch_size):
+ integrate = 0.0
+ frames = torch.zeros((hidden_size), device=hidden.device)
+ list_frame = []
+ list_fire = []
+ for t in range(len_time):
+ alpha = alphas[b][t]
+ if alpha + integrate < self.threshold:
+ integrate += alpha
+ list_fire.append(integrate)
+ frames += alpha * hidden[b][t]
+ else:
+ frames += (self.threshold - integrate) * hidden[b][t]
+ list_frame.append(frames)
+ integrate += alpha
+ list_fire.append(integrate)
+ integrate -= self.threshold
+ frames = integrate * hidden[b][t]
+
+ cache_alphas.append(integrate)
+ if integrate > 0.0:
+ cache_hiddens.append(frames / integrate)
+ else:
+ cache_hiddens.append(frames)
+
+ token_length.append(torch.tensor(len(list_frame), device=alphas.device))
+ list_fires.append(list_fire)
+ list_frames.append(list_frame)
+
+ cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
+ cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
+ cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
+ cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
+
+ max_token_len = max(token_length)
+ if max_token_len == 0:
+ return hidden, torch.stack(token_length, 0)
+ list_ls = []
+ for b in range(batch_size):
+ pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device)
+ if token_length[b] == 0:
+ list_ls.append(pad_frames)
+ else:
+ list_frames[b] = torch.stack(list_frames[b])
+ list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0))
+
+ cache["cif_alphas"] = torch.stack(cache_alphas, axis=0)
+ cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
+ cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
+ cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
+ return torch.stack(list_ls, 0), torch.stack(token_length, 0)
+
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
@@ -407,7 +499,7 @@
fire_place = integrate >= threshold
integrate = torch.where(fire_place,
- integrate - torch.ones([batch_size], device=alphas.device),
+ integrate - torch.ones([batch_size], device=alphas.device)*threshold,
integrate)
fires = torch.stack(list_fires, 1)
@@ -657,3 +749,128 @@
predictor_alignments = index_div_bool_zeros_count_tile_out
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
return predictor_alignments.detach(), predictor_alignments_length.detach()
+
+class BATPredictor(nn.Module):
+ def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, return_accum=False):
+ super(BATPredictor, self).__init__()
+
+ self.pad = nn.ConstantPad1d((l_order, r_order), 0)
+ self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
+ self.cif_output = nn.Linear(idim, 1)
+ self.dropout = torch.nn.Dropout(p=dropout)
+ self.threshold = threshold
+ self.smooth_factor = smooth_factor
+ self.noise_threshold = noise_threshold
+ self.return_accum = return_accum
+
+ def cif(
+ self,
+ input: Tensor,
+ alpha: Tensor,
+ beta: float = 1.0,
+ return_accum: bool = False,
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ B, S, C = input.size()
+ assert tuple(alpha.size()) == (B, S), f"{alpha.size()} != {(B, S)}"
+
+ dtype = alpha.dtype
+ alpha = alpha.float()
+
+ alpha_sum = alpha.sum(1)
+ feat_lengths = (alpha_sum / beta).floor().long()
+ T = feat_lengths.max()
+
+ # aggregate and integrate
+ csum = alpha.cumsum(-1)
+ with torch.no_grad():
+ # indices used for scattering
+ right_idx = (csum / beta).floor().long().clip(max=T)
+ left_idx = right_idx.roll(1, dims=1)
+ left_idx[:, 0] = 0
+
+ # count # of fires from each source
+ fire_num = right_idx - left_idx
+ extra_weights = (fire_num - 1).clip(min=0)
+ # The extra entry in last dim is for
+ output = input.new_zeros((B, T + 1, C))
+ source_range = torch.arange(1, 1 + S).unsqueeze(0).type_as(input)
+ zero = alpha.new_zeros((1,))
+
+ # right scatter
+ fire_mask = fire_num > 0
+ right_weight = torch.where(
+ fire_mask,
+ csum - right_idx.type_as(alpha) * beta,
+ zero
+ ).type_as(input)
+ # assert right_weight.ge(0).all(), f"{right_weight} should be non-negative."
+ output.scatter_add_(
+ 1,
+ right_idx.unsqueeze(-1).expand(-1, -1, C),
+ right_weight.unsqueeze(-1) * input
+ )
+
+ # left scatter
+ left_weight = (
+ alpha - right_weight - extra_weights.type_as(alpha) * beta
+ ).type_as(input)
+ output.scatter_add_(
+ 1,
+ left_idx.unsqueeze(-1).expand(-1, -1, C),
+ left_weight.unsqueeze(-1) * input
+ )
+
+ # extra scatters
+ if extra_weights.ge(0).any():
+ extra_steps = extra_weights.max().item()
+ tgt_idx = left_idx
+ src_feats = input * beta
+ for _ in range(extra_steps):
+ tgt_idx = (tgt_idx + 1).clip(max=T)
+ # (B, S, 1)
+ src_mask = (extra_weights > 0)
+ output.scatter_add_(
+ 1,
+ tgt_idx.unsqueeze(-1).expand(-1, -1, C),
+ src_feats * src_mask.unsqueeze(2)
+ )
+ extra_weights -= 1
+
+ output = output[:, :T, :]
+
+ if return_accum:
+ return output, csum
+ else:
+ return output, alpha
+
+ def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None, target_label_length=None):
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ memory = self.cif_conv1d(queries)
+ output = memory + context
+ output = self.dropout(output)
+ output = output.transpose(1, 2)
+ output = torch.relu(output)
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas*self.smooth_factor - self.noise_threshold)
+ if mask is not None:
+ alphas = alphas * mask.transpose(-1, -2).float()
+ if mask_chunk_predictor is not None:
+ alphas = alphas * mask_chunk_predictor
+ alphas = alphas.squeeze(-1)
+ if target_label_length is not None:
+ target_length = target_label_length
+ elif target_label is not None:
+ target_length = (target_label != ignore_id).float().sum(-1)
+ # logging.info("target_length: {}".format(target_length))
+ else:
+ target_length = None
+ token_num = alphas.sum(-1)
+ if target_length is not None:
+ # length_noise = torch.rand(alphas.size(0), device=alphas.device) - 0.5
+ # target_length = length_noise + target_length
+ alphas *= ((target_length + 1e-4) / token_num)[:, None].repeat(1, alphas.size(1))
+ acoustic_embeds, cif_peak = self.cif(hidden, alphas, self.threshold, self.return_accum)
+ return acoustic_embeds, token_num, alphas, cif_peak
--
Gitblit v1.9.1