#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# Copyright 2024 Kun Zou (chinazoukun@gmail.com). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import logging

import numpy as np
import torch
from torch.cuda.amp import autocast

from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
@tables.register("predictor_classes", "PifPredictor")
class PifPredictor(torch.nn.Module):
    """Parallel integrate-and-fire (PIF) predictor.

    Predicts per-frame firing weights (alphas) with a depthwise conv head,
    then gathers token-level acoustic embeddings with a per-head Gaussian
    soft-attention centred on the cumulative-alpha fire positions.

    Author: Kun Zou, chinazoukun@gmail.com
    E-Paraformer: A Faster and Better Parallel Transformer for
    Non-autoregressive End-to-End Mandarin Speech Recognition
    https://www.isca-archive.org/interspeech_2024/zou24_interspeech.pdf
    """

    def __init__(
        self,
        idim,
        l_order,
        r_order,
        threshold=1.0,
        dropout=0.1,
        smooth_factor=1.0,
        noise_threshold=0,
        sigma=0.5,
        bias=0.0,
        sigma_heads=4,
    ):
        """Build the predictor.

        Args:
            idim: hidden feature dimension; ``forward`` splits it into
                ``sigma_heads`` equal chunks, so it must be divisible by
                ``sigma_heads``.
            l_order: left context of the depthwise convolution.
            r_order: right context of the depthwise convolution.
            threshold: kept for interface compatibility with other CIF-style
                predictors; not used inside this module's ``forward``.
            dropout: dropout probability applied after the conv residual.
            smooth_factor: multiplicative scaling of the sigmoid alphas.
            noise_threshold: value subtracted from the scaled alphas before
                the ReLU clamp (suppresses small/noisy weights).
            sigma: initial per-head Gaussian width of the fire window
                (learnable).
            bias: initial per-head additive score bias (learnable).
            sigma_heads: number of Gaussian attention heads.
        """
        super().__init__()

        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
        # Depthwise conv (groups=idim): per-channel temporal smoothing.
        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
        self.cif_output = torch.nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        # Learnable per-head Gaussian width and score bias.
        self.sigma = torch.nn.Parameter(torch.tensor([sigma] * sigma_heads))
        self.bias = torch.nn.Parameter(torch.tensor([bias] * sigma_heads))
        self.sigma_heads = sigma_heads

    def forward(
        self,
        hidden,
        target_label=None,
        mask=None,
        ignore_id=-1,
        mask_chunk_predictor=None,
        target_label_length=None,
    ):
        """Compute token embeddings and firing weights.

        Args:
            hidden: encoder output, (batch, time, idim) — assumed from the
                transpose/view pattern below; confirm against callers.
            target_label: padded label ids used (at train time) to derive the
                target length and output mask; padding marked by ``ignore_id``.
            mask: optional frame mask; transposed and broadcast against the
                (batch, time, 1) alphas, then squeezed to (batch, time).
            ignore_id: padding id inside ``target_label``.
            mask_chunk_predictor: optional extra multiplicative alpha mask.
            target_label_length: explicit target lengths; takes priority over
                lengths derived from ``target_label``.

        Returns:
            Tuple ``(acoustic_embeds, token_num, alphas, cif_peak)`` where
            ``cif_peak`` is always ``None`` for this predictor.
        """
        # Alphas/attention are computed in fp32 regardless of AMP.
        with autocast(False):
            h = hidden
            context = h.transpose(1, 2)
            queries = self.pad(context)
            memory = self.cif_conv1d(queries)
            output = memory + context  # residual around the depthwise conv
            output = self.dropout(output)
            output = output.transpose(1, 2)
            output = torch.relu(output)
            output = self.cif_output(output)
            alphas = torch.sigmoid(output)
            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
            if mask is not None:
                mask = mask.transpose(-1, -2).float()
                alphas = alphas * mask
            if mask_chunk_predictor is not None:
                alphas = alphas * mask_chunk_predictor
            alphas = alphas.squeeze(-1)
            if mask is not None:
                # BUGFIX: original squeezed unconditionally and crashed when
                # mask was None, despite mask being an optional argument.
                mask = mask.squeeze(-1)

            # BUGFIX: derive target_mask whenever labels are available.  The
            # original only assigned target_mask on the elif branch, so passing
            # target_label_length raised UnboundLocalError at the final
            # ``if target_mask is not None`` check.
            if target_label is not None:
                target_mask = (target_label != ignore_id).float()
            else:
                target_mask = None
            if target_label_length is not None:
                target_length = target_label_length
            elif target_mask is not None:
                target_length = target_mask.sum(-1)
            else:
                target_length = None

            token_num = alphas.sum(-1)
            if target_length is not None:
                # Rescale alphas so they sum exactly to the target length.
                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
                max_token_num = torch.max(target_length)
            else:
                # Inference: round the predicted length and rescale to it.
                token_num_int = token_num.round()
                alphas *= (token_num_int / token_num)[:, None]
                max_token_num = torch.max(token_num_int)

            # Soft monotonic alignment: token j fires where cumsum crosses j+0.5.
            alignment = torch.cumsum(alphas, dim=-1)
            fire_positions = (torch.arange(max_token_num) + 0.5).type_as(alphas).unsqueeze(0)
            # Per-head Gaussian score around each fire position:
            # (batch, heads, tokens, time).
            scores = (
                -(
                    (fire_positions[:, None, :, None] - alignment[:, None, None, :])
                    * self.sigma[None, :, None, None]
                )
                ** 2
                + self.bias[None, :, None, None]
            )
            if mask is not None:
                scores = scores.masked_fill(
                    ~(mask[:, None, None, :].to(torch.bool)), float("-inf")
                )
            weights = torch.softmax(scores, dim=-1)
            # Split hidden into heads, attend, then merge heads back.
            n_hidden = hidden.view(
                hidden.size(0), -1, self.sigma_heads, hidden.size(-1) // self.sigma_heads
            ).transpose(1, 2)
            acoustic_embeds = (
                torch.matmul(weights, n_hidden)
                .transpose(1, 2)
                .contiguous()
                .view(hidden.size(0), -1, hidden.size(-1))
            )

            if target_mask is not None:
                # Zero out embeddings at padded target positions.
                acoustic_embeds *= target_mask[:, :, None]

        cif_peak = None
        return acoustic_embeds, token_num, alphas, cif_peak