游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/modules/nets_utils.py
@@ -61,6 +61,48 @@
    return pad
def pad_list_all_dim(xs, pad_value):
    """Perform padding for the list of tensors.
    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.
    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).
    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    n_batch = len(xs)
    num_dim = len(xs[0].shape)
    max_len_all_dim = []
    for i in range(num_dim):
        max_len_all_dim.append(max(x.size(i) for x in xs))
    pad = xs[0].new(n_batch, *max_len_all_dim).fill_(pad_value)
    for i in range(n_batch):
        if num_dim == 1:
            pad[i, : xs[i].size(0)] = xs[i]
        elif num_dim == 2:
            pad[i, : xs[i].size(0), : xs[i].size(1)] = xs[i]
        elif num_dim == 3:
            pad[i, : xs[i].size(0), : xs[i].size(1), : xs[i].size(2)] = xs[i]
        else:
            raise ValueError(
                "pad_list_all_dim only support 1-D, 2-D and 3-D tensors, not {}-D.".format(num_dim)
            )
    return pad
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Make mask tensor containing indices of padded part.
@@ -407,7 +449,7 @@
    elif mode == "mt" and arch == "rnn":
        # +1 means input (+1) and layers outputs (train_args.elayer)
        subsample = np.ones(train_args.elayers + 1, dtype=np.int)
        subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
        logging.warning("Subsampling is not performed for machine translation.")
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample
@@ -417,7 +459,7 @@
            or (mode == "mt" and arch == "rnn")
            or (mode == "st" and arch == "rnn")
    ):
        subsample = np.ones(train_args.elayers + 1, dtype=np.int)
        subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers + 1, len(ss))):
@@ -432,7 +474,7 @@
    elif mode == "asr" and arch == "rnn_mix":
        subsample = np.ones(
            train_args.elayers_sd + train_args.elayers + 1, dtype=np.int
            train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32
        )
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
@@ -451,7 +493,7 @@
    elif mode == "asr" and arch == "rnn_mulenc":
        subsample_list = []
        for idx in range(train_args.num_encs):
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int)
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int32)
            if train_args.etype[idx].endswith("p") and not train_args.etype[
                idx
            ].startswith("vgg"):
@@ -485,14 +527,39 @@
        new_k = k.replace(old_prefix, new_prefix)
        state_dict[new_k] = v
class Swish(torch.nn.Module):
    """Construct an Swish object."""
    """Swish activation definition.
    def forward(self, x):
        """Return Swich activation function."""
        return x * torch.sigmoid(x)
    Swish(x) = (beta * x) * sigmoid(x)
                 where beta = 1 defines standard Swish activation.
    References:
        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
        E-swish variant: https://arxiv.org/abs/1801.07145.
    Args:
        beta: Beta parameter for E-Swish.
                (beta >= 1. If beta < 1, use standard Swish).
        use_builtin: Whether to use PyTorch function if available.
    """
    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
        super().__init__()
        self.beta = beta
        if beta > 1:
            self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
        else:
            if use_builtin:
                self.swish = torch.nn.SiLU()
            else:
                self.swish = lambda x: x * torch.sigmoid(x)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward computation."""
        return self.swish(x)
def get_activation(act):
    """Return activation function."""
@@ -595,7 +662,7 @@
    mask = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if left_chunk_size <= 0:
        if left_chunk_size < 0:
            start = 0
        else:
            start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)