| | |
| | | elif isinstance(m, torch.Tensor): |
| | | device = m.device |
| | | else: |
| | | raise TypeError( |
| | | "Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}" |
| | | ) |
| | | raise TypeError("Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}") |
| | | return x.to(device) |
| | | |
| | | |
| | |
| | | if length_dim < 0: |
| | | length_dim = xs.dim() + length_dim |
| | | # ind = (:, None, ..., None, :, , None, ..., None) |
| | | ind = tuple( |
| | | slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) |
| | | ) |
| | | ind = tuple(slice(None) if i in (0, length_dim) else None for i in range(xs.dim())) |
| | | mask = mask[ind].expand_as(xs).to(xs.device) |
| | | return mask |
| | | |
| | |
| | | return ret |
| | | |
| | | |
def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    """
    batch, lmax = pad_targets.size(0), pad_targets.size(1)
    # Fold the flat (B * Lmax, D) logits back to (B, Lmax, D) and take the
    # best-scoring class per position.
    predictions = pad_outputs.view(batch, lmax, pad_outputs.size(1)).argmax(2)
    valid = pad_targets != ignore_label
    # Compare prediction and target only at positions not marked with the
    # ignore label.
    correct = (
        predictions.masked_select(valid) == pad_targets.masked_select(valid)
    ).sum()
    total = valid.sum()
    return float(correct) / float(total)
| | | |
| | | |
| | | def to_torch_tensor(x): |
| | | """Change to torch.Tensor or ComplexTensor from numpy.ndarray. |
| | | |
| | |
| | | return subsample |
| | | |
| | | elif ( |
| | | (mode == "asr" and arch in ("rnn", "rnn-t")) |
| | | or (mode == "mt" and arch == "rnn") |
| | | or (mode == "st" and arch == "rnn") |
| | | (mode == "asr" and arch in ("rnn", "rnn-t")) |
| | | or (mode == "mt" and arch == "rnn") |
| | | or (mode == "st" and arch == "rnn") |
| | | ): |
| | | subsample = np.ones(train_args.elayers + 1, dtype=np.int32) |
| | | if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): |
| | |
| | | return subsample |
| | | |
| | | elif mode == "asr" and arch == "rnn_mix": |
| | | subsample = np.ones( |
| | | train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32 |
| | | ) |
| | | subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32) |
| | | if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): |
| | | ss = train_args.subsample.split("_") |
| | | for j in range( |
| | | min(train_args.elayers_sd + train_args.elayers + 1, len(ss)) |
| | | ): |
| | | for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))): |
| | | subsample[j] = int(ss[j]) |
| | | else: |
| | | logging.warning( |
| | |
| | | subsample_list = [] |
| | | for idx in range(train_args.num_encs): |
| | | subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int32) |
| | | if train_args.etype[idx].endswith("p") and not train_args.etype[ |
| | | idx |
| | | ].startswith("vgg"): |
| | | if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"): |
| | | ss = train_args.subsample[idx].split("_") |
| | | for j in range(min(train_args.elayers[idx] + 1, len(ss))): |
| | | subsample[j] = int(ss[j]) |
| | |
| | | raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch)) |
| | | |
| | | |
def rename_state_dict(
    old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
):
    """Replace keys of old prefix with new prefix in state dict.

    Mutates ``state_dict`` in place: every key starting with ``old_prefix``
    is removed and re-inserted under the renamed key.

    Args:
        old_prefix: Key prefix to look for.
        new_prefix: Replacement prefix.
        state_dict: Mapping of parameter names to tensors; modified in place.

    """
    # need this list not to break the dict iterator
    old_keys = [k for k in state_dict if k.startswith(old_prefix)]
    for k in old_keys:
        v = state_dict.pop(k)
        # NOTE(review): str.replace swaps every occurrence of old_prefix in
        # the key, not only the leading one — keep in mind for keys that
        # repeat the prefix internally.
        new_k = k.replace(old_prefix, new_prefix)
        state_dict[new_k] = v
| | | |
| | | |
| | | class Swish(torch.nn.Module): |
| | | """Swish activation definition. |
| | |
| | | """Forward computation.""" |
| | | return self.swish(x) |
| | | |
| | | |
| | | def get_activation(act): |
| | | """Return activation function.""" |
| | | |
| | |
| | | } |
| | | |
| | | return activation_funcs[act]() |
| | | |
| | | |
| | | class TooShortUttError(Exception): |
| | | """Raised when the utt is too short for subsampling. |
| | |
| | | elif sub_factor == 6: |
| | | return 5, 3, (((input_size - 1) // 2 - 2) // 3) |
| | | else: |
| | | raise ValueError( |
| | | "subsampling_factor parameter should be set to either 2, 4 or 6." |
| | | ) |
| | | raise ValueError("subsampling_factor parameter should be set to either 2, 4 or 6.") |
| | | |
| | | |
| | | def make_chunk_mask( |
| | |
| | | mask[i, start:end] = True |
| | | |
| | | return ~mask |
| | | |
| | | |
| | | def make_source_mask(lengths: torch.Tensor) -> torch.Tensor: |
| | | """Create source mask for given lengths. |
| | |
| | | |
| | | return decoder_in, target, t_len, u_len |
| | | |
| | | |
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros.

    Args:
        t: Tensor to pad.
        pad_len: Target length along ``dim``; assumed >= ``t.size(dim)`` —
            TODO confirm callers guarantee this (a shorter pad_len would
            make ``torch.zeros`` raise on the negative size).
        dim: Dimension along which to pad.

    Returns:
        Tensor whose size at ``dim`` equals ``pad_len``; the input tensor
        itself (no copy) when it already has that length.
    """
    if t.size(dim) == pad_len:
        # Already the requested length: return the tensor unchanged.
        return t
    else:
        pad_size = list(t.shape)
        pad_size[dim] = pad_len - t.size(dim)
        # Append a zero block of the missing length on the right, matching
        # dtype and device of the input.
        return torch.cat(
            [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
        )