aky15
2023-03-21 8a100b731efba8c18f7e7b6cb1cb04ded94248b3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""Utility functions for Transducer models."""
 
from typing import List, Tuple
 
import torch
 
 
class TooShortUttError(Exception):
    """Raised when the utt is too short for subsampling.
 
    Args:
        message: Error message to display.
        actual_size: The size that cannot pass the subsampling.
        limit: The size limit for subsampling.
 
    """
 
    def __init__(self, message: str, actual_size: int, limit: int) -> None:
        """Construct a TooShortUttError module."""
        super().__init__(message)
 
        self.actual_size = actual_size
        self.limit = limit
 
 
def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
    """Check if the input is too short for subsampling.
 
    Args:
        sub_factor: Subsampling factor for Conv2DSubsampling.
        size: Input size.
 
    Returns:
        : Whether an error should be sent.
        : Size limit for specified subsampling factor.
 
    """
    if sub_factor == 2 and size < 3:
        return True, 7
    elif sub_factor == 4 and size < 7:
        return True, 7
    elif sub_factor == 6 and size < 11:
        return True, 11
 
    return False, -1
 
 
def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
    """Get conv2D second layer parameters for given subsampling factor.
 
    Args:
        sub_factor: Subsampling factor (1/X).
        input_size: Input size.
 
    Returns:
        : Kernel size for second convolution.
        : Stride for second convolution.
        : Conv2DSubsampling output size.
 
    """
    if sub_factor == 2:
        return 3, 1, (((input_size - 1) // 2 - 2))
    elif sub_factor == 4:
        return 3, 2, (((input_size - 1) // 2 - 1) // 2)
    elif sub_factor == 6:
        return 5, 3, (((input_size - 1) // 2 - 2) // 3)
    else:
        raise ValueError(
            "subsampling_factor parameter should be set to either 2, 4 or 6."
        )
 
 
def make_chunk_mask(
    size: int,
    chunk_size: int,
    left_chunk_size: int = 0,
    device: torch.device = None,
) -> torch.Tensor:
    """Create chunk mask for the subsequent steps (size, size).
 
    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
 
    Args:
        size: Size of the source mask.
        chunk_size: Number of frames in chunk.
        left_chunk_size: Size of the left context in chunks (0 means full context).
        device: Device for the mask tensor.
 
    Returns:
        mask: Chunk mask. (size, size)
 
    """
    mask = torch.zeros(size, size, device=device, dtype=torch.bool)
 
    for i in range(size):
        if left_chunk_size <= 0:
            start = 0
        else:
            start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
 
        end = min((i // chunk_size + 1) * chunk_size, size)
        mask[i, start:end] = True
 
    return ~mask
 
 
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Create source mask for given lengths.
 
    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
 
    Args:
        lengths: Sequence lengths. (B,)
 
    Returns:
        : Mask for the sequence lengths. (B, max_len)
 
    """
    max_len = lengths.max()
    batch_size = lengths.size(0)
 
    expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
 
    return expanded_lengths >= lengths.unsqueeze(1)
 
 
def get_transducer_task_io(
    labels: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ignore_id: int = -1,
    blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get Transducer loss I/O.
 
    Args:
        labels: Label ID sequences. (B, L)
        encoder_out_lens: Encoder output lengths. (B,)
        ignore_id: Padding symbol ID.
        blank_id: Blank symbol ID.
 
    Returns:
        decoder_in: Decoder inputs. (B, U)
        target: Target label ID sequences. (B, U)
        t_len: Time lengths. (B,)
        u_len: Label lengths. (B,)
 
    """
 
    def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
        """Create padded batch of labels from a list of labels sequences.
 
        Args:
            labels: Labels sequences. [B x (?)]
            padding_value: Padding value.
 
        Returns:
            labels: Batch of padded labels sequences. (B,)
 
        """
        batch_size = len(labels)
 
        padded = (
            labels[0]
            .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
            .fill_(padding_value)
        )
 
        for i in range(batch_size):
            padded[i, : labels[i].size(0)] = labels[i]
 
        return padded
 
    device = labels.device
 
    labels_unpad = [y[y != ignore_id] for y in labels]
    blank = labels[0].new([blank_id])
 
    decoder_in = pad_list(
        [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
    ).to(device)
 
    target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
 
    encoder_out_lens = list(map(int, encoder_out_lens))
    t_len = torch.IntTensor(encoder_out_lens).to(device)
 
    u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
 
    return decoder_in, target, t_len, u_len
 
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
    if t.size(dim) == pad_len:
        return t
    else:
        pad_size = list(t.shape)
        pad_size[dim] = pad_len - t.size(dim)
        return torch.cat(
            [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
        )