From a308356d5c7c165bc4c9f8732f1fe920dcd4b67a Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: 星期一, 03 七月 2023 15:02:11 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR
---
funasr/models/predictor/cif.py | 127 ++++++++++++++++++++++++++++++++++++++++++
1 files changed, 127 insertions(+), 0 deletions(-)
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index 3c363db..c66af94 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -1,10 +1,12 @@
import torch
from torch import nn
+from torch import Tensor
import logging
import numpy as np
from funasr.torch_utils.device_funcs import to_device
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.streaming_utils.utils import sequence_mask
+from typing import Optional, Tuple
class CifPredictor(nn.Module):
def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
@@ -747,3 +749,128 @@
predictor_alignments = index_div_bool_zeros_count_tile_out
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
return predictor_alignments.detach(), predictor_alignments_length.detach()
+
+class BATPredictor(nn.Module):
+ def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, return_accum=False):
+ super(BATPredictor, self).__init__()
+
+ self.pad = nn.ConstantPad1d((l_order, r_order), 0)
+ self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
+ self.cif_output = nn.Linear(idim, 1)
+ self.dropout = torch.nn.Dropout(p=dropout)
+ self.threshold = threshold
+ self.smooth_factor = smooth_factor
+ self.noise_threshold = noise_threshold
+ self.return_accum = return_accum
+
+ def cif(
+ self,
+ input: Tensor,
+ alpha: Tensor,
+ beta: float = 1.0,
+ return_accum: bool = False,
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ B, S, C = input.size()
+ assert tuple(alpha.size()) == (B, S), f"{alpha.size()} != {(B, S)}"
+
+ dtype = alpha.dtype
+ alpha = alpha.float()
+
+ alpha_sum = alpha.sum(1)
+ feat_lengths = (alpha_sum / beta).floor().long()
+ T = feat_lengths.max()
+
+ # aggregate and integrate
+ csum = alpha.cumsum(-1)
+ with torch.no_grad():
+ # indices used for scattering
+ right_idx = (csum / beta).floor().long().clip(max=T)
+ left_idx = right_idx.roll(1, dims=1)
+ left_idx[:, 0] = 0
+
+ # count # of fires from each source
+ fire_num = right_idx - left_idx
+ extra_weights = (fire_num - 1).clip(min=0)
+ # The extra entry in last dim is for
+ output = input.new_zeros((B, T + 1, C))
+ source_range = torch.arange(1, 1 + S).unsqueeze(0).type_as(input)
+ zero = alpha.new_zeros((1,))
+
+ # right scatter
+ fire_mask = fire_num > 0
+ right_weight = torch.where(
+ fire_mask,
+ csum - right_idx.type_as(alpha) * beta,
+ zero
+ ).type_as(input)
+ # assert right_weight.ge(0).all(), f"{right_weight} should be non-negative."
+ output.scatter_add_(
+ 1,
+ right_idx.unsqueeze(-1).expand(-1, -1, C),
+ right_weight.unsqueeze(-1) * input
+ )
+
+ # left scatter
+ left_weight = (
+ alpha - right_weight - extra_weights.type_as(alpha) * beta
+ ).type_as(input)
+ output.scatter_add_(
+ 1,
+ left_idx.unsqueeze(-1).expand(-1, -1, C),
+ left_weight.unsqueeze(-1) * input
+ )
+
+ # extra scatters
+ if extra_weights.ge(0).any():
+ extra_steps = extra_weights.max().item()
+ tgt_idx = left_idx
+ src_feats = input * beta
+ for _ in range(extra_steps):
+ tgt_idx = (tgt_idx + 1).clip(max=T)
+ # (B, S, 1)
+ src_mask = (extra_weights > 0)
+ output.scatter_add_(
+ 1,
+ tgt_idx.unsqueeze(-1).expand(-1, -1, C),
+ src_feats * src_mask.unsqueeze(2)
+ )
+ extra_weights -= 1
+
+ output = output[:, :T, :]
+
+ if return_accum:
+ return output, csum
+ else:
+ return output, alpha
+
+ def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None, target_label_length=None):
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ memory = self.cif_conv1d(queries)
+ output = memory + context
+ output = self.dropout(output)
+ output = output.transpose(1, 2)
+ output = torch.relu(output)
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas*self.smooth_factor - self.noise_threshold)
+ if mask is not None:
+ alphas = alphas * mask.transpose(-1, -2).float()
+ if mask_chunk_predictor is not None:
+ alphas = alphas * mask_chunk_predictor
+ alphas = alphas.squeeze(-1)
+ if target_label_length is not None:
+ target_length = target_label_length
+ elif target_label is not None:
+ target_length = (target_label != ignore_id).float().sum(-1)
+ # logging.info("target_length: {}".format(target_length))
+ else:
+ target_length = None
+ token_num = alphas.sum(-1)
+ if target_length is not None:
+ # length_noise = torch.rand(alphas.size(0), device=alphas.device) - 0.5
+ # target_length = length_noise + target_length
+ alphas *= ((target_length + 1e-4) / token_num)[:, None].repeat(1, alphas.size(1))
+ acoustic_embeds, cif_peak = self.cif(hidden, alphas, self.threshold, self.return_accum)
+ return acoustic_embeds, token_num, alphas, cif_peak
--
Gitblit v1.9.1