From 4ace5a95b052d338947fc88809a440ccd55cf6b4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 十一月 2023 16:39:52 +0800
Subject: [PATCH] funasr pages

---
 funasr/modules/nets_utils.py |  280 ++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 files changed, 270 insertions(+), 10 deletions(-)

diff --git a/funasr/modules/nets_utils.py b/funasr/modules/nets_utils.py
index 6d77d69..b1879fa 100644
--- a/funasr/modules/nets_utils.py
+++ b/funasr/modules/nets_utils.py
@@ -3,7 +3,7 @@
 """Network related utility tools."""
 
 import logging
-from typing import Dict
+from typing import Dict, List, Tuple
 
 import numpy as np
 import torch
@@ -57,6 +57,48 @@
 
     for i in range(n_batch):
         pad[i, : xs[i].size(0)] = xs[i]
+
+    return pad
+
+
+def pad_list_all_dim(xs, pad_value):
+    """Perform padding for the list of tensors.
+
+    Args:
+        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
+        pad_value (float): Value for padding.
+
+    Returns:
+        Tensor: Padded tensor (B, Tmax, `*`).
+
+    Examples:
+        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
+        >>> x
+        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
+        >>> pad_list(x, 0)
+        tensor([[1., 1., 1., 1.],
+                [1., 1., 0., 0.],
+                [1., 0., 0., 0.]])
+
+    """
+    n_batch = len(xs)
+    num_dim = len(xs[0].shape)
+    max_len_all_dim = []
+    for i in range(num_dim):
+        max_len_all_dim.append(max(x.size(i) for x in xs))
+    pad = xs[0].new(n_batch, *max_len_all_dim).fill_(pad_value)
+
+    for i in range(n_batch):
+        if num_dim == 1:
+            pad[i, : xs[i].size(0)] = xs[i]
+        elif num_dim == 2:
+            pad[i, : xs[i].size(0), : xs[i].size(1)] = xs[i]
+        elif num_dim == 3:
+            pad[i, : xs[i].size(0), : xs[i].size(1), : xs[i].size(2)] = xs[i]
+        else:
+            raise ValueError(
+                "pad_list_all_dim only support 1-D, 2-D and 3-D tensors, not {}-D.".format(num_dim)
+            )
 
     return pad
 
@@ -407,7 +449,7 @@
 
     elif mode == "mt" and arch == "rnn":
         # +1 means input (+1) and layers outputs (train_args.elayer)
-        subsample = np.ones(train_args.elayers + 1, dtype=np.int)
+        subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
         logging.warning("Subsampling is not performed for machine translation.")
         logging.info("subsample: " + " ".join([str(x) for x in subsample]))
         return subsample
@@ -417,7 +459,7 @@
             or (mode == "mt" and arch == "rnn")
             or (mode == "st" and arch == "rnn")
     ):
-        subsample = np.ones(train_args.elayers + 1, dtype=np.int)
+        subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
         if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
             ss = train_args.subsample.split("_")
             for j in range(min(train_args.elayers + 1, len(ss))):
@@ -432,7 +474,7 @@
 
     elif mode == "asr" and arch == "rnn_mix":
         subsample = np.ones(
-            train_args.elayers_sd + train_args.elayers + 1, dtype=np.int
+            train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32
         )
         if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
             ss = train_args.subsample.split("_")
@@ -451,7 +493,7 @@
     elif mode == "asr" and arch == "rnn_mulenc":
         subsample_list = []
         for idx in range(train_args.num_encs):
-            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int)
+            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int32)
             if train_args.etype[idx].endswith("p") and not train_args.etype[
                 idx
             ].startswith("vgg"):
@@ -485,14 +527,39 @@
         new_k = k.replace(old_prefix, new_prefix)
         state_dict[new_k] = v
 
-
 class Swish(torch.nn.Module):
-    """Construct an Swish object."""
+    """Swish activation definition.
 
-    def forward(self, x):
-        """Return Swich activation function."""
-        return x * torch.sigmoid(x)
+    Swish(x) = (beta * x) * sigmoid(x)
+                 where beta = 1 defines standard Swish activation.
 
+    References:
+        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
+        E-swish variant: https://arxiv.org/abs/1801.07145.
+
+    Args:
+        beta: Beta parameter for E-Swish.
+                (beta >= 1. If beta < 1, use standard Swish).
+        use_builtin: Whether to use PyTorch function if available.
+
+    """
+
+    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
+        super().__init__()
+
+        self.beta = beta
+
+        if beta > 1:
+            self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
+        else:
+            if use_builtin:
+                self.swish = torch.nn.SiLU()
+            else:
+                self.swish = lambda x: x * torch.sigmoid(x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward computation."""
+        return self.swish(x)
 
 def get_activation(act):
     """Return activation function."""
@@ -506,3 +573,196 @@
     }
 
     return activation_funcs[act]()
+
+class TooShortUttError(Exception):
+    """Raised when the utt is too short for subsampling.
+
+    Args:
+        message: Error message to display.
+        actual_size: The size that cannot pass the subsampling.
+        limit: The size limit for subsampling.
+
+    """
+
+    def __init__(self, message: str, actual_size: int, limit: int) -> None:
+        """Construct a TooShortUttError module."""
+        super().__init__(message)
+
+        self.actual_size = actual_size
+        self.limit = limit
+
+
+def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
+    """Check if the input is too short for subsampling.
+
+    Args:
+        sub_factor: Subsampling factor for Conv2DSubsampling.
+        size: Input size.
+
+    Returns:
+        : Whether an error should be sent.
+        : Size limit for specified subsampling factor.
+
+    """
+    if sub_factor == 2 and size < 3:
+        return True, 7
+    elif sub_factor == 4 and size < 7:
+        return True, 7
+    elif sub_factor == 6 and size < 11:
+        return True, 11
+
+    return False, -1
+
+
+def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
+    """Get conv2D second layer parameters for given subsampling factor.
+
+    Args:
+        sub_factor: Subsampling factor (1/X).
+        input_size: Input size.
+
+    Returns:
+        : Kernel size for second convolution.
+        : Stride for second convolution.
+        : Conv2DSubsampling output size.
+
+    """
+    if sub_factor == 2:
+        return 3, 1, (((input_size - 1) // 2 - 2))
+    elif sub_factor == 4:
+        return 3, 2, (((input_size - 1) // 2 - 1) // 2)
+    elif sub_factor == 6:
+        return 5, 3, (((input_size - 1) // 2 - 2) // 3)
+    else:
+        raise ValueError(
+            "subsampling_factor parameter should be set to either 2, 4 or 6."
+        )
+
+
+def make_chunk_mask(
+    size: int,
+    chunk_size: int,
+    left_chunk_size: int = 0,
+    device: torch.device = None,
+) -> torch.Tensor:
+    """Create chunk mask for the subsequent steps (size, size).
+
+    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+
+    Args:
+        size: Size of the source mask.
+        chunk_size: Number of frames in chunk.
+        left_chunk_size: Size of the left context in chunks (0 means full context).
+        device: Device for the mask tensor.
+
+    Returns:
+        mask: Chunk mask. (size, size)
+
+    """
+    mask = torch.zeros(size, size, device=device, dtype=torch.bool)
+
+    for i in range(size):
+        if left_chunk_size < 0:
+            start = 0
+        else:
+            start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
+
+        end = min((i // chunk_size + 1) * chunk_size, size)
+        mask[i, start:end] = True
+
+    return ~mask
+
+def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
+    """Create source mask for given lengths.
+
+    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+
+    Args:
+        lengths: Sequence lengths. (B,)
+
+    Returns:
+        : Mask for the sequence lengths. (B, max_len)
+
+    """
+    max_len = lengths.max()
+    batch_size = lengths.size(0)
+
+    expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
+
+    return expanded_lengths >= lengths.unsqueeze(1)
+
+
+def get_transducer_task_io(
+    labels: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    ignore_id: int = -1,
+    blank_id: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Get Transducer loss I/O.
+
+    Args:
+        labels: Label ID sequences. (B, L)
+        encoder_out_lens: Encoder output lengths. (B,)
+        ignore_id: Padding symbol ID.
+        blank_id: Blank symbol ID.
+
+    Returns:
+        decoder_in: Decoder inputs. (B, U)
+        target: Target label ID sequences. (B, U)
+        t_len: Time lengths. (B,)
+        u_len: Label lengths. (B,)
+
+    """
+
+    def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
+        """Create padded batch of labels from a list of labels sequences.
+
+        Args:
+            labels: Labels sequences. [B x (?)]
+            padding_value: Padding value.
+
+        Returns:
+            labels: Batch of padded labels sequences. (B,)
+
+        """
+        batch_size = len(labels)
+
+        padded = (
+            labels[0]
+            .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
+            .fill_(padding_value)
+        )
+
+        for i in range(batch_size):
+            padded[i, : labels[i].size(0)] = labels[i]
+
+        return padded
+
+    device = labels.device
+
+    labels_unpad = [y[y != ignore_id] for y in labels]
+    blank = labels[0].new([blank_id])
+
+    decoder_in = pad_list(
+        [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
+    ).to(device)
+
+    target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
+
+    encoder_out_lens = list(map(int, encoder_out_lens))
+    t_len = torch.IntTensor(encoder_out_lens).to(device)
+
+    u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
+
+    return decoder_in, target, t_len, u_len
+
+def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
+    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
+    if t.size(dim) == pad_len:
+        return t
+    else:
+        pad_size = list(t.shape)
+        pad_size[dim] = pad_len - t.size(dim)
+        return torch.cat(
+            [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
+        )

--
Gitblit v1.9.1