游雁
2024-01-13 bdfd27b9e96bd55c449953bb577e1d4deeaf11c9
funasr/models/paraformer/cif_predictor.py
@@ -8,9 +8,9 @@
from funasr.models.scama.utils import sequence_mask
from typing import Optional, Tuple
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables
-@register_class("predictor_classes", "CifPredictor")
+@tables.register("predictor_classes", "CifPredictor")
class CifPredictor(nn.Module):
    def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
        super().__init__()
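# --- Note: the hunks above and below migrate class registration from the old
# --- funasr.utils.register.register_class helper to the tables.register
# --- decorator in funasr.register. A minimal sketch of how a decorator-based
# --- registry of this shape can work (an illustrative stand-in, not FunASR's
# --- actual implementation):
class _Tables:
    def __init__(self):
        # Maps a group name such as "predictor_classes" to {name: class}.
        self._groups = {}

    def register(self, group, name):
        # Decorator factory: @tables.register("predictor_classes", "CifPredictor")
        def decorator(cls):
            self._groups.setdefault(group, {})[name] = cls
            return cls
        return decorator

    def get(self, group, name):
        return self._groups[group][name]

tables = _Tables()  # stand-in for the real singleton in funasr/register.py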
@@ -136,7 +136,7 @@
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()
@register_class("predictor_classes", "CifPredictorV2")
@tables.register("predictor_classes", "CifPredictorV2")
class CifPredictorV2(nn.Module):
    def __init__(self,
                 idim,
@@ -205,7 +205,8 @@
        return acoustic_embeds, token_num, alphas, cif_peak
-    def forward_chunk(self, hidden, cache=None):
+    def forward_chunk(self, hidden, cache=None, **kwargs):
+        is_final = kwargs.get("is_final", False)
        batch_size, len_time, hidden_size = hidden.shape
        h = hidden
        context = h.transpose(1, 2)
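# --- Note: with **kwargs on forward_chunk, streaming callers now pass the
# --- end-of-stream flag as an explicit keyword argument instead of storing it
# --- in the cache dict. Hypothetical before/after call site:
#     cache["is_final"] = True               # old: flag carried in the cache
#     predictor.forward_chunk(chunk, cache)
#     predictor.forward_chunk(chunk, cache, is_final=True)  # new: explicit kwarg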
@@ -226,14 +227,14 @@
        if cache is not None and "chunk_size" in cache:
            alphas[:, :cache["chunk_size"][0]] = 0.0
            if "is_final" in cache and not cache["is_final"]:
            if not is_final:
                alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
            cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
            hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
            alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
        if cache is not None and "is_final" in cache and cache["is_final"]:
        if cache is not None and is_final:
            tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
            tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
            tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
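# --- Note: CIF (Continuous Integrate-and-Fire) accumulates the per-frame
# --- weights `alphas` and emits one token embedding each time the running sum
# --- crosses `threshold`; appending `tail_threshold` on the final chunk tops
# --- up the accumulator so a trailing partial token still fires. A minimal
# --- single-utterance sketch of the firing rule (illustrative only; the real
# --- cif() in this file is batched and vectorized):
import torch

def cif_fire(hidden, alphas, threshold=1.0):
    """hidden: (T, D) frames; alphas: (T,) weights. Returns (N, D) tokens."""
    integrate, frame, fired = 0.0, torch.zeros(hidden.size(1)), []
    for h, a in zip(hidden, alphas.tolist()):
        if integrate + a < threshold:
            integrate += a
            frame = frame + a * h
        else:
            remain = threshold - integrate    # weight needed to reach threshold
            fired.append(frame + remain * h)  # fire one token embedding
            integrate = a - remain            # leftover weight opens the next token
            frame = integrate * h
    return torch.stack(fired) if fired else hidden.new_zeros(0, hidden.size(1))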
@@ -277,7 +278,7 @@
        max_token_len = max(token_length)
        if max_token_len == 0:
-            return hidden, torch.stack(token_length, 0)
+            return hidden, torch.stack(token_length, 0), None, None
        list_ls = []
        for b in range(batch_size):
            pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device)
@@ -291,7 +292,7 @@
        cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
        cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
        cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
-        return torch.stack(list_ls, 0), torch.stack(token_length, 0)
+        return torch.stack(list_ls, 0), torch.stack(token_length, 0), None, None
    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
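# --- Note: both return statements above are padded to four values so that
# --- forward_chunk matches the arity of forward (acoustic_embeds, token_num,
# --- alphas, cif_peak) and call sites can unpack uniformly. Hypothetical
# --- streaming loop (predictor construction and chunking are assumed):
cache = {}
for i, chunk in enumerate(encoder_chunks):
    acoustic_embeds, token_num, _, _ = predictor.forward_chunk(
        chunk, cache=cache, is_final=(i == len(encoder_chunks) - 1)
    )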