| | |
| | | """Network related utility tools.""" |
| | | |
| | | import logging |
| | | from typing import Dict |
| | | from typing import Dict, List, Tuple |
| | | |
| | | import numpy as np |
| | | import torch |
| | |
| | | |
| | | for i in range(n_batch): |
| | | pad[i, : xs[i].size(0)] = xs[i] |
| | | |
| | | return pad |
| | | |
| | | |
def pad_list_all_dim(xs, pad_value):
    """Pad a list of tensors along every dimension.

    Each dimension of the output is the maximum size of that dimension over
    all tensors in ``xs``; the region not covered by a tensor is filled with
    ``pad_value``.

    Args:
        xs (List): List of Tensors with the same number of dimensions,
            e.g. [(T_1, D_1), (T_2, D_2), ..., (T_B, D_B)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, max_b T_b, max_b D_b, ...).

    Raises:
        ValueError: If the tensors in ``xs`` do not all have the same number
            of dimensions.

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list_all_dim(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    num_dim = xs[0].dim()
    if any(x.dim() != num_dim for x in xs):
        raise ValueError(
            "pad_list_all_dim requires all tensors to have the same number of "
            "dimensions, got {}.".format(sorted({x.dim() for x in xs}))
        )

    max_len_all_dim = [max(x.size(d) for x in xs) for d in range(num_dim)]
    pad = xs[0].new(n_batch, *max_len_all_dim).fill_(pad_value)

    for i, x in enumerate(xs):
        # Index (i, :x.size(0), :x.size(1), ...) selects exactly the leading
        # corner that x occupies, for any number of dimensions.
        region = (i,) + tuple(slice(0, s) for s in x.shape)
        pad[region] = x

    return pad
| | | |
| | |
| | | |
| | | Args: |
| | | pad_outputs (Tensor): Prediction tensors (B * Lmax, D). |
| | | pad_targets (LongTensor): Target label tensors (B, Lmax, D). |
| | | pad_targets (LongTensor): Target label tensors (B, Lmax). |
| | | ignore_label (int): Ignore label id. |
| | | |
| | | Returns: |
| | |
| | | |
| | | elif mode == "mt" and arch == "rnn": |
| | | # +1 means input (+1) and layers outputs (train_args.elayer) |
| | | subsample = np.ones(train_args.elayers + 1, dtype=np.int) |
| | | subsample = np.ones(train_args.elayers + 1, dtype=np.int32) |
| | | logging.warning("Subsampling is not performed for machine translation.") |
| | | logging.info("subsample: " + " ".join([str(x) for x in subsample])) |
| | | return subsample |
| | |
| | | or (mode == "mt" and arch == "rnn") |
| | | or (mode == "st" and arch == "rnn") |
| | | ): |
| | | subsample = np.ones(train_args.elayers + 1, dtype=np.int) |
| | | subsample = np.ones(train_args.elayers + 1, dtype=np.int32) |
| | | if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): |
| | | ss = train_args.subsample.split("_") |
| | | for j in range(min(train_args.elayers + 1, len(ss))): |
| | |
| | | |
| | | elif mode == "asr" and arch == "rnn_mix": |
| | | subsample = np.ones( |
| | | train_args.elayers_sd + train_args.elayers + 1, dtype=np.int |
| | | train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32 |
| | | ) |
| | | if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): |
| | | ss = train_args.subsample.split("_") |
| | |
| | | elif mode == "asr" and arch == "rnn_mulenc": |
| | | subsample_list = [] |
| | | for idx in range(train_args.num_encs): |
| | | subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int) |
| | | subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int32) |
| | | if train_args.etype[idx].endswith("p") and not train_args.etype[ |
| | | idx |
| | | ].startswith("vgg"): |
| | |
| | | new_k = k.replace(old_prefix, new_prefix) |
| | | state_dict[new_k] = v |
| | | |
| | | |
class Swish(torch.nn.Module):
    """Swish activation definition.

    Swish(x) = (beta * x) * sigmoid(x)
    where beta = 1 defines standard Swish activation.

    References:
        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
        E-swish variant: https://arxiv.org/abs/1801.07145.

    Args:
        beta: Beta parameter for E-Swish. E-Swish is used only when beta > 1;
            any other value uses standard Swish (i.e. beta = 1).
        use_builtin: Whether to use PyTorch's ``torch.nn.SiLU`` for standard
            Swish (ignored when beta > 1).

    """

    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
        """Construct a Swish object."""
        super().__init__()

        self.beta = beta

        # Bound methods are used instead of lambdas: lambdas cannot be pickled
        # (breaking torch.save on the whole model), and copy.deepcopy would keep
        # a lambda closing over the ORIGINAL module's beta. Bound methods pickle
        # by reference and are rebound to the copy by deepcopy.
        if beta > 1:
            self.swish = self._e_swish
        elif use_builtin:
            self.swish = torch.nn.SiLU()
        else:
            self.swish = self._standard_swish

    def _e_swish(self, x: torch.Tensor) -> torch.Tensor:
        """E-Swish: (beta * x) * sigmoid(x)."""
        return (self.beta * x) * torch.sigmoid(x)

    def _standard_swish(self, x: torch.Tensor) -> torch.Tensor:
        """Standard Swish: x * sigmoid(x)."""
        return x * torch.sigmoid(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward computation."""
        return self.swish(x)
| | | |
| | | def get_activation(act): |
| | | """Return activation function.""" |
| | |
| | | } |
| | | |
| | | return activation_funcs[act]() |
| | | |
class TooShortUttError(Exception):
    """Error signalling that an utterance cannot go through subsampling.

    Args:
        message: Text shown when the error is printed.
        actual_size: Input size that failed the subsampling check.
        limit: Size limit the subsampling requires.

    """

    def __init__(self, message: str, actual_size: int, limit: int) -> None:
        """Keep the offending size and the limit next to the message."""
        super().__init__(message)

        self.limit = limit
        self.actual_size = actual_size
| | | |
| | | |
def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
    """Check if the input is too short for subsampling.

    The limits are the smallest input sizes whose subsampled output size is
    at least 1:
        factor 2: (size - 1) // 2 - 2 >= 1         -> size >= 7
        factor 4: ((size - 1) // 2 - 1) // 2 >= 1  -> size >= 7
        factor 6: ((size - 1) // 2 - 2) // 3 >= 1  -> size >= 11

    Args:
        sub_factor: Subsampling factor for Conv2DSubsampling.
        size: Input size.

    Returns:
        : Whether an error should be sent.
        : Size limit for specified subsampling factor (-1 when no error).

    """
    # Factor 2 previously checked `size < 3` while reporting a limit of 7, so
    # sizes 3..6 passed the check but produced a non-positive output size.
    if sub_factor == 2 and size < 7:
        return True, 7
    elif sub_factor == 4 and size < 7:
        return True, 7
    elif sub_factor == 6 and size < 11:
        return True, 11

    return False, -1
| | | |
| | | |
def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
    """Get conv2D second layer parameters for given subsampling factor.

    Args:
        sub_factor: Subsampling factor (1/X). Must be 2, 4 or 6.
        input_size: Input size.

    Returns:
        : Kernel size for second convolution.
        : Stride for second convolution.
        : Conv2DSubsampling output size.

    Raises:
        ValueError: If sub_factor is not 2, 4 or 6.

    """
    if sub_factor not in (2, 4, 6):
        raise ValueError(
            "subsampling_factor parameter should be set to either 2, 4 or 6."
        )

    kernel_size, stride = {2: (3, 1), 4: (3, 2), 6: (5, 3)}[sub_factor]

    # Size after the first stage, common to every factor.
    halved = (input_size - 1) // 2

    # Number of positions a kernel_size-wide window visits over `halved`
    # elements when moving by `stride`.
    output_size = (halved - kernel_size) // stride + 1

    return kernel_size, stride, output_size
| | | |
| | | |
def make_chunk_mask(
    size: int,
    chunk_size: int,
    left_chunk_size: int = 0,
    device: torch.device = None,
) -> torch.Tensor:
    """Create chunk mask for the subsequent steps (size, size).

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Frame ``i`` belongs to chunk ``i // chunk_size``. It may attend to every
    frame of its own chunk plus the frames of the ``left_chunk_size``
    preceding chunks.

    Args:
        size: Size of the source mask.
        chunk_size: Number of frames in chunk.
        left_chunk_size: Number of left-context chunks each chunk may see.
            A negative value means full left context; 0 means the current
            chunk only.
        device: Device for the mask tensor.

    Returns:
        mask: Chunk mask. (size, size)
            Boolean; True marks positions that must NOT be attended to.

    """
    positions = torch.arange(size, device=device)
    chunk_idx = positions // chunk_size

    # Exclusive end of each row's visible window: the end of its own chunk.
    # Ends beyond `size` are harmless because column indices stop at size - 1.
    end = (chunk_idx + 1) * chunk_size

    if left_chunk_size < 0:
        start = torch.zeros_like(chunk_idx)
    else:
        start = ((chunk_idx - left_chunk_size) * chunk_size).clamp(min=0)

    # Broadcast columns (1, size) against per-row bounds (size, 1).
    cols = positions.unsqueeze(0)
    visible = (cols >= start.unsqueeze(1)) & (cols < end.unsqueeze(1))

    return ~visible
| | | |
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Build a padding mask from a batch of sequence lengths.

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        lengths: Sequence lengths. (B,)

    Returns:
        : Boolean mask (B, max_len); True where the time index is greater
          than or equal to the sequence length (padding), False otherwise.

    """
    n_seq = lengths.size(0)

    steps = torch.arange(lengths.max()).expand(n_seq, -1).to(lengths)

    return steps >= lengths[:, None]
| | | |
| | | |
def get_transducer_task_io(
    labels: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ignore_id: int = -1,
    blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get Transducer loss I/O.

    Args:
        labels: Label ID sequences. (B, L)
        encoder_out_lens: Encoder output lengths. (B,)
        ignore_id: Padding symbol ID.
        blank_id: Blank symbol ID.

    Returns:
        decoder_in: Decoder inputs: blank_id prepended to each label sequence
            (with ignore_id removed), right-padded with blank_id. (B, U + 1)
        target: Label sequences with ignore_id removed, right-padded with
            blank_id, as int32. (B, U)
        t_len: Time lengths, as int32. (B,)
        u_len: Label lengths (count of non-ignore_id labels), as int32. (B,)

        U is the longest label sequence after removing ignore_id.

    """

    def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
        """Create padded batch of labels from a list of labels sequences.

        Args:
            labels: Labels sequences. [B x (?)]
            padding_value: Padding value.

        Returns:
            labels: Batch of padded labels sequences. (B, max_len, ...)

        """
        batch_size = len(labels)

        padded = (
            labels[0]
            .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
            .fill_(padding_value)
        )

        for i in range(batch_size):
            padded[i, : labels[i].size(0)] = labels[i]

        return padded

    device = labels.device

    labels_unpad = [y[y != ignore_id] for y in labels]
    blank = labels[0].new([blank_id])

    decoder_in = pad_list(
        [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
    ).to(device)

    target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)

    # int() accepts both tensor elements and plain Python ints, so a list of
    # lengths works as well as a tensor.
    encoder_out_lens = list(map(int, encoder_out_lens))
    t_len = torch.tensor(encoder_out_lens, dtype=torch.int32, device=device)

    u_len = torch.tensor(
        [y.size(0) for y in labels_unpad], dtype=torch.int32, device=device
    )

    return decoder_in, target, t_len, u_len
| | | |
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
    """Right-pad ``t`` with zeros along ``dim`` until it has length ``pad_len``.

    When ``t`` already has that length it is returned as-is (same object);
    otherwise a new tensor is built by concatenating a zero block of the same
    dtype and device. A ``pad_len`` shorter than the current length is not
    supported, because the zero block's size would be negative.
    """
    current = t.size(dim)
    if current == pad_len:
        return t

    filler_shape = list(t.shape)
    filler_shape[dim] = pad_len - current
    zeros = torch.zeros(*filler_shape, dtype=t.dtype, device=t.device)

    return torch.cat([t, zeros], dim=dim)