aky15
2023-04-12 7d1efe158eda74dc847c397db906f6cb77ac0f84
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""Conv1d block for Transducer encoder."""
 
from typing import Optional, Tuple, Union
 
import torch
 
 
class Conv1d(torch.nn.Module):
    """Conv1d module definition.
 
    Args:
        input_size: Input dimension.
        output_size: Output dimension.
        kernel_size: Size of the convolving kernel.
        stride: Stride of the convolution.
        dilation: Spacing between the kernel points.
        groups: Number of blocked connections from input channels to output channels.
        bias: Whether to add a learnable bias to the output.
        batch_norm: Whether to use batch normalization after convolution.
        relu: Whether to use a ReLU activation after convolution.
        causal: Whether to use causal convolution (set to True if streaming).
        dropout_rate: Dropout rate.
 
    """
 
    def __init__(
        self,
        input_size: int,
        output_size: int,
        kernel_size: Union[int, Tuple],
        stride: Union[int, Tuple] = 1,
        dilation: Union[int, Tuple] = 1,
        groups: Union[int, Tuple] = 1,
        bias: bool = True,
        batch_norm: bool = False,
        relu: bool = True,
        causal: bool = False,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a Conv1d object."""
        super().__init__()
 
        if causal:
            self.lorder = kernel_size - 1
            stride = 1
        else:
            self.lorder = 0
            stride = stride
 
        self.conv = torch.nn.Conv1d(
            input_size,
            output_size,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
 
        self.dropout = torch.nn.Dropout(p=dropout_rate)
 
        if relu:
            self.relu_func = torch.nn.ReLU()
 
        if batch_norm:
            self.bn = torch.nn.BatchNorm1d(output_size)
 
        self.out_pos = torch.nn.Linear(input_size, output_size)
 
        self.input_size = input_size
        self.output_size = output_size
 
        self.relu = relu
        self.batch_norm = batch_norm
        self.causal = causal
 
        self.kernel_size = kernel_size
        self.padding = dilation * (kernel_size - 1)
        self.stride = stride
 
        self.cache = None
 
    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset Conv1d cache for streaming.
 
        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.
 
        """
        self.cache = torch.zeros(
            (1, self.input_size, self.kernel_size - 1), device=device
        )
 
    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences.
 
        Args:
            x: Conv1d input sequences. (B, T, D_in)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_in)
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)
 
        Returns:
            x: Conv1d output sequences. (B, sub(T), D_out)
            mask: Source mask. (B, T) or (B, sub(T))
            pos_enc: Positional embedding sequences.
                       (B, 2 * (T - 1), D_att) or (B, 2 * (sub(T) - 1), D_out)
 
        """
        x = x.transpose(1, 2)
 
        if self.lorder > 0:
            x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
        else:
            mask = self.create_new_mask(mask)
            pos_enc = self.create_new_pos_enc(pos_enc)
 
        x = self.conv(x)
 
        if self.batch_norm:
            x = self.bn(x)
 
        x = self.dropout(x)
 
        if self.relu:
            x = self.relu_func(x)
 
        x = x.transpose(1, 2)
 
        return x, mask, self.out_pos(pos_enc)
 
    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        left_context: int = 0,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk of input sequence.
 
        Args:
            x: Conv1d input sequences. (B, T, D_in)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_in)
            mask: Source mask. (B, T)
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.
 
        Returns:
            x: Conv1d output sequences. (B, T, D_out)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_out)
 
        """
        x = torch.cat([self.cache, x.transpose(1, 2)], dim=2)
 
        if right_context > 0:
            self.cache = x[:, :, -(self.lorder + right_context) : -right_context]
        else:
            self.cache = x[:, :, -self.lorder :]
 
        x = self.conv(x)
 
        if self.batch_norm:
            x = self.bn(x)
 
        x = self.dropout(x)
 
        if self.relu:
            x = self.relu_func(x)
 
        x = x.transpose(1, 2)
 
        return x, self.out_pos(pos_enc)
 
    def create_new_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create new mask for output sequences.
 
        Args:
            mask: Mask of input sequences. (B, T)
 
        Returns:
            mask: Mask of output sequences. (B, sub(T))
 
        """
        if self.padding != 0:
            mask = mask[:, : -self.padding]
 
        return mask[:, :: self.stride]
 
    def create_new_pos_enc(self, pos_enc: torch.Tensor) -> torch.Tensor:
        """Create new positional embedding vector.
 
        Args:
            pos_enc: Input sequences positional embedding.
                     (B, 2 * (T - 1), D_in)
 
        Returns:
            pos_enc: Output sequences positional embedding.
                     (B, 2 * (sub(T) - 1), D_in)
 
        """
        pos_enc_positive = pos_enc[:, : pos_enc.size(1) // 2 + 1, :]
        pos_enc_negative = pos_enc[:, pos_enc.size(1) // 2 :, :]
 
        if self.padding != 0:
            pos_enc_positive = pos_enc_positive[:, : -self.padding, :]
            pos_enc_negative = pos_enc_negative[:, : -self.padding, :]
 
        pos_enc_positive = pos_enc_positive[:, :: self.stride, :]
        pos_enc_negative = pos_enc_negative[:, :: self.stride, :]
 
        pos_enc = torch.cat([pos_enc_positive, pos_enc_negative[:, 1:, :]], dim=1)
 
        return pos_enc