kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# Copyright 2024 Kun Zou (chinazoukun@gmail.com). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
 
import torch
import logging
import numpy as np
 
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from torch.cuda.amp import autocast
 
 
@tables.register("predictor_classes", "PifPredictor")
class PifPredictor(torch.nn.Module):
    """Predictor that turns frame-level encoder states into token-level embeddings.

    Author: Kun Zou, chinazoukun@gmail.com
    E-Paraformer: A Faster and Better Parallel Transformer for Non-autoregressive End-to-End Mandarin Speech Recognition
    https://www.isca-archive.org/interspeech_2024/zou24_interspeech.pdf

    For every frame a non-negative weight (``alphas``) is predicted.  The
    cumulative sum of these weights is compared against the "fire positions"
    ``k + 0.5`` (``k = 0 .. num_tokens - 1``); a per-head negative squared
    distance, scaled by the learnable ``sigma`` and shifted by the learnable
    ``bias``, is soft-maxed over the frame axis and used to average the hidden
    states head by head.
    """

    def __init__(
        self,
        idim,
        l_order,
        r_order,
        threshold=1.0,
        dropout=0.1,
        smooth_factor=1.0,
        noise_threshold=0,
        sigma=0.5,
        bias=0.0,
        sigma_heads=4,
    ):
        """Build the predictor.

        Args:
            idim: Feature size of the hidden states (last dim of ``hidden``).
            l_order: Number of zero frames padded on the left before the
                depth-wise convolution.
            r_order: Number of zero frames padded on the right.
            threshold: Stored on the module; not read by ``forward``.
            dropout: Dropout probability applied after the convolution.
            smooth_factor: Multiplier applied to the sigmoid output.
            noise_threshold: Value subtracted from the scaled sigmoid output
                before the ReLU.
            sigma: Initial value of every entry of the per-head scale.
            bias: Initial value of every entry of the per-head score offset.
            sigma_heads: Number of heads the feature dimension is split into.

        Raises:
            ValueError: If ``sigma_heads`` is not positive or does not divide
                ``idim`` (``forward`` splits the feature axis into
                ``sigma_heads`` equal parts).
        """
        super().__init__()

        # Fail early: forward() reshapes features to (heads, idim // heads).
        if sigma_heads < 1 or idim % sigma_heads != 0:
            raise ValueError(
                f"idim ({idim}) must be divisible by a positive sigma_heads ({sigma_heads})"
            )

        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
        # Depth-wise conv (groups=idim); with the padding above the time
        # length is preserved.
        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
        self.cif_output = torch.nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.sigma = torch.nn.Parameter(torch.tensor([sigma] * sigma_heads))
        self.bias = torch.nn.Parameter(torch.tensor([bias] * sigma_heads))
        self.sigma_heads = sigma_heads

    def forward(
        self,
        hidden,
        target_label=None,
        mask=None,
        ignore_id=-1,
        mask_chunk_predictor=None,
        target_label_length=None,
    ):
        """Predict frame weights and pool ``hidden`` into token embeddings.

        Args:
            hidden: Tensor of shape (batch, frames, idim).
            target_label: Optional (batch, max_tokens) label tensor; entries
                equal to ``ignore_id`` count as padding.  Only used when
                ``target_label_length`` is None.
            mask: Optional frame validity mask; expected shape is
                (batch, 1, frames) with non-zero meaning "valid" (it is
                transposed to (batch, frames, 1) before use).  When None all
                frames are treated as valid.
            ignore_id: Padding id inside ``target_label``.
            mask_chunk_predictor: Optional tensor multiplied into the frame
                weights (must broadcast against (batch, frames, 1)).
            target_label_length: Optional (batch,) tensor of token counts.
                Takes precedence over ``target_label``.

        Returns:
            Tuple ``(acoustic_embeds, token_num, alphas, cif_peak)``:
            ``acoustic_embeds`` is (batch, num_tokens, idim); ``token_num`` is
            the per-utterance sum of the frame weights *before* rescaling;
            ``alphas`` are the rescaled frame weights (batch, frames);
            ``cif_peak`` is always None.
        """
        with autocast(False):
            context = hidden.transpose(1, 2)
            memory = self.cif_conv1d(self.pad(context))
            output = self.dropout(memory + context)
            output = torch.relu(output.transpose(1, 2))
            output = self.cif_output(output)
            alphas = torch.sigmoid(output)
            alphas = torch.nn.functional.relu(
                alphas * self.smooth_factor - self.noise_threshold
            )

            frame_mask = None
            if mask is not None:
                mask = mask.transpose(-1, -2).float()
                alphas = alphas * mask
                frame_mask = mask.squeeze(-1).to(torch.bool)
            if mask_chunk_predictor is not None:
                alphas = alphas * mask_chunk_predictor
            alphas = alphas.squeeze(-1)

            # target_mask stays None unless a target is available; it is
            # initialised up-front so every branch below leaves it defined.
            target_mask = None
            if target_label_length is not None:
                target_length = target_label_length
            elif target_label is not None:
                target_mask = (target_label != ignore_id).float()
                target_length = target_mask.sum(-1)
            else:
                target_length = None

            token_num = alphas.sum(-1)
            # Rescale the weights so that they sum to the desired token count.
            # Done out-of-place: alphas may be a view of the ReLU output, which
            # autograd needs unmodified for the backward pass.
            if target_length is not None:
                scale = (target_length / token_num).to(alphas.dtype)
                max_token_num = torch.max(target_length)
            else:
                token_num_int = token_num.round()
                scale = (token_num_int / token_num).to(alphas.dtype)
                max_token_num = torch.max(token_num_int)
            alphas = alphas * scale[:, None]

            alignment = torch.cumsum(alphas, dim=-1)
            fire_positions = (torch.arange(max_token_num) + 0.5).type_as(alphas).unsqueeze(0)

            # scores: (batch, heads, num_tokens, frames)
            distance = fire_positions[:, None, :, None] - alignment[:, None, None, :]
            scores = -((distance * self.sigma[None, :, None, None]) ** 2)
            scores = scores + self.bias[None, :, None, None]
            if frame_mask is not None:
                scores = scores.masked_fill(~frame_mask[:, None, None, :], float("-inf"))
            weights = torch.softmax(scores, dim=-1)

            batch_size, feat_dim = hidden.size(0), hidden.size(-1)
            # (batch, frames, idim) -> (batch, heads, frames, idim // heads)
            n_hidden = hidden.reshape(
                batch_size, -1, self.sigma_heads, feat_dim // self.sigma_heads
            ).transpose(1, 2)
            acoustic_embeds = (
                torch.matmul(weights, n_hidden)
                .transpose(1, 2)
                .contiguous()
                .view(batch_size, -1, feat_dim)
            )

            if target_mask is None and target_label_length is not None:
                # Only lengths were given: derive the same kind of padding
                # mask the label path uses (1 for real tokens, 0 for padding).
                token_index = torch.arange(acoustic_embeds.size(1), device=alphas.device)
                target_mask = token_index[None, :] < target_length[:, None]

            if target_mask is not None:
                acoustic_embeds = acoustic_embeds * target_mask[:, :, None].to(
                    acoustic_embeds.dtype
                )
            cif_peak = None
        return acoustic_embeds, token_num, alphas, cif_peak