from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union

import re

import numpy as np
import torch
import yaml

logger_initialized = {}

def pad_list(xs, pad_value, max_len=None):
    """Pad a list of variable-length tensors into one batched tensor."""
    n_batch = len(xs)
    if max_len is None:
        max_len = max(x.size(0) for x in xs)
    # Allocate the output with the inputs' dtype/device, pre-filled with pad_value.
    pad = xs[0].new_full((n_batch, max_len, *xs[0].size()[1:]), pad_value)

    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]

    return pad

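# Quick sanity check (hypothetical values), assuming 1-D inputs:
#   pad_list([torch.tensor([1, 2, 3]), torch.tensor([4, 5])], 0)
#   -> tensor([[1, 2, 3],
#              [4, 5, 0]])
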
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Make a boolean mask that is True at padded positions."""
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if maxlen is None:
        if xs is None:
            maxlen = int(max(lengths))
        else:
            maxlen = xs.size(length_dim)
    else:
        assert xs is None
        assert maxlen >= int(max(lengths))

    # Positions at or beyond each sequence's length are padding.
    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new_tensor(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, None, ..., None), keeping the
        # batch dim and length_dim so the mask broadcasts to xs's shape.
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask

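# Example (hypothetical lengths): make_pad_mask([2, 3]) returns
#   tensor([[False, False,  True],
#           [False, False, False]])
# Passing xs instead broadcasts the mask to xs's shape along length_dim.
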
class TokenIDConverter:
    def __init__(
        self,
        token_list: Union[List, str],
    ):