游雁
2023-12-06 27f31cd42bb4e20dc19de0034fc0d80b449f1db1
funasr/modules/nets_utils.py
@@ -3,7 +3,7 @@
"""Network related utility tools."""
import logging
from typing import Dict
from typing import Dict, List, Tuple
import numpy as np
import torch
@@ -57,6 +57,48 @@
    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]
    return pad
def pad_list_all_dim(xs, pad_value):
    """Perform padding for the list of tensors.
    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.
    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).
    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    n_batch = len(xs)
    num_dim = len(xs[0].shape)
    max_len_all_dim = []
    for i in range(num_dim):
        max_len_all_dim.append(max(x.size(i) for x in xs))
    pad = xs[0].new(n_batch, *max_len_all_dim).fill_(pad_value)
    for i in range(n_batch):
        if num_dim == 1:
            pad[i, : xs[i].size(0)] = xs[i]
        elif num_dim == 2:
            pad[i, : xs[i].size(0), : xs[i].size(1)] = xs[i]
        elif num_dim == 3:
            pad[i, : xs[i].size(0), : xs[i].size(1), : xs[i].size(2)] = xs[i]
        else:
            raise ValueError(
                "pad_list_all_dim only support 1-D, 2-D and 3-D tensors, not {}-D.".format(num_dim)
            )
    return pad
@@ -305,7 +347,7 @@
    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.
    Returns:
@@ -407,7 +449,7 @@
    elif mode == "mt" and arch == "rnn":
        # +1 means input (+1) and layers outputs (train_args.elayer)
        subsample = np.ones(train_args.elayers + 1, dtype=np.int)
        subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
        logging.warning("Subsampling is not performed for machine translation.")
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample
@@ -417,7 +459,7 @@
            or (mode == "mt" and arch == "rnn")
            or (mode == "st" and arch == "rnn")
    ):
        subsample = np.ones(train_args.elayers + 1, dtype=np.int)
        subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers + 1, len(ss))):
@@ -432,7 +474,7 @@
    elif mode == "asr" and arch == "rnn_mix":
        subsample = np.ones(
            train_args.elayers_sd + train_args.elayers + 1, dtype=np.int
            train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32
        )
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
@@ -451,7 +493,7 @@
    elif mode == "asr" and arch == "rnn_mulenc":
        subsample_list = []
        for idx in range(train_args.num_encs):
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int)
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int32)
            if train_args.etype[idx].endswith("p") and not train_args.etype[
                idx
            ].startswith("vgg"):
@@ -485,14 +527,39 @@
        new_k = k.replace(old_prefix, new_prefix)
        state_dict[new_k] = v
class Swish(torch.nn.Module):
    """Construct an Swish object."""
    """Swish activation definition.
    def forward(self, x):
        """Return Swich activation function."""
        return x * torch.sigmoid(x)
    Swish(x) = (beta * x) * sigmoid(x)
                 where beta = 1 defines standard Swish activation.
    References:
        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
        E-swish variant: https://arxiv.org/abs/1801.07145.
    Args:
        beta: Beta parameter for E-Swish.
                (beta >= 1. If beta < 1, use standard Swish).
        use_builtin: Whether to use PyTorch function if available.
    """
    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
        super().__init__()
        self.beta = beta
        if beta > 1:
            self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
        else:
            if use_builtin:
                self.swish = torch.nn.SiLU()
            else:
                self.swish = lambda x: x * torch.sigmoid(x)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward computation."""
        return self.swish(x)
def get_activation(act):
    """Return activation function."""
@@ -506,3 +573,196 @@
    }
    return activation_funcs[act]()
class TooShortUttError(Exception):
    """Raised when the utt is too short for subsampling.
    Args:
        message: Error message to display.
        actual_size: The size that cannot pass the subsampling.
        limit: The size limit for subsampling.
    """
    def __init__(self, message: str, actual_size: int, limit: int) -> None:
        """Construct a TooShortUttError module."""
        super().__init__(message)
        self.actual_size = actual_size
        self.limit = limit
def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
    """Check if the input is too short for subsampling.
    Args:
        sub_factor: Subsampling factor for Conv2DSubsampling.
        size: Input size.
    Returns:
        : Whether an error should be sent.
        : Size limit for specified subsampling factor.
    """
    if sub_factor == 2 and size < 3:
        return True, 7
    elif sub_factor == 4 and size < 7:
        return True, 7
    elif sub_factor == 6 and size < 11:
        return True, 11
    return False, -1
def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
    """Get conv2D second layer parameters for given subsampling factor.
    Args:
        sub_factor: Subsampling factor (1/X).
        input_size: Input size.
    Returns:
        : Kernel size for second convolution.
        : Stride for second convolution.
        : Conv2DSubsampling output size.
    """
    if sub_factor == 2:
        return 3, 1, (((input_size - 1) // 2 - 2))
    elif sub_factor == 4:
        return 3, 2, (((input_size - 1) // 2 - 1) // 2)
    elif sub_factor == 6:
        return 5, 3, (((input_size - 1) // 2 - 2) // 3)
    else:
        raise ValueError(
            "subsampling_factor parameter should be set to either 2, 4 or 6."
        )
def make_chunk_mask(
    size: int,
    chunk_size: int,
    left_chunk_size: int = 0,
    device: torch.device = None,
) -> torch.Tensor:
    """Create chunk mask for the subsequent steps (size, size).
    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
    Args:
        size: Size of the source mask.
        chunk_size: Number of frames in chunk.
        left_chunk_size: Size of the left context in chunks (0 means full context).
        device: Device for the mask tensor.
    Returns:
        mask: Chunk mask. (size, size)
    """
    mask = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if left_chunk_size < 0:
            start = 0
        else:
            start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
        end = min((i // chunk_size + 1) * chunk_size, size)
        mask[i, start:end] = True
    return ~mask
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Create source mask for given lengths.
    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
    Args:
        lengths: Sequence lengths. (B,)
    Returns:
        : Mask for the sequence lengths. (B, max_len)
    """
    max_len = lengths.max()
    batch_size = lengths.size(0)
    expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
    return expanded_lengths >= lengths.unsqueeze(1)
def get_transducer_task_io(
    labels: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ignore_id: int = -1,
    blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get Transducer loss I/O.
    Args:
        labels: Label ID sequences. (B, L)
        encoder_out_lens: Encoder output lengths. (B,)
        ignore_id: Padding symbol ID.
        blank_id: Blank symbol ID.
    Returns:
        decoder_in: Decoder inputs. (B, U)
        target: Target label ID sequences. (B, U)
        t_len: Time lengths. (B,)
        u_len: Label lengths. (B,)
    """
    def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
        """Create padded batch of labels from a list of labels sequences.
        Args:
            labels: Labels sequences. [B x (?)]
            padding_value: Padding value.
        Returns:
            labels: Batch of padded labels sequences. (B,)
        """
        batch_size = len(labels)
        padded = (
            labels[0]
            .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
            .fill_(padding_value)
        )
        for i in range(batch_size):
            padded[i, : labels[i].size(0)] = labels[i]
        return padded
    device = labels.device
    labels_unpad = [y[y != ignore_id] for y in labels]
    blank = labels[0].new([blank_id])
    decoder_in = pad_list(
        [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
    ).to(device)
    target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
    encoder_out_lens = list(map(int, encoder_out_lens))
    t_len = torch.IntTensor(encoder_out_lens).to(device)
    u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
    return decoder_in, target, t_len, u_len
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
    if t.size(dim) == pad_len:
        return t
    else:
        pad_size = list(t.shape)
        pad_size[dim] = pad_len - t.size(dim)
        return torch.cat(
            [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
        )