aky15
2023-04-12 7d1efe158eda74dc847c397db906f6cb77ac0f84
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""ConvInput block for Transducer encoder."""
 
from typing import Optional, Tuple, Union
 
import torch
import math
 
from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
 
 
class ConvInput(torch.nn.Module):
    """ConvInput module definition.
 
    Args:
        input_size: Input size.
        conv_size: Convolution size.
        subsampling_factor: Subsampling factor.
        vgg_like: Whether to use a VGG-like network.
        output_size: Block output dimension.
 
    """
 
    def __init__(
        self,
        input_size: int,
        conv_size: Union[int, Tuple],
        subsampling_factor: int = 4,
        vgg_like: bool = True,
        output_size: Optional[int] = None,
    ) -> None:
        """Construct a ConvInput object."""
        super().__init__()
        if vgg_like:
            if subsampling_factor == 1:
                conv_size1, conv_size2 = conv_size
 
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((1, 2)),
                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((1, 2)),
                )
 
                output_proj = conv_size2 * ((input_size // 2) // 2)
 
                self.subsampling_factor = 1
 
                self.stride_1 = 1
 
                self.create_new_mask = self.create_new_vgg_mask
 
            else:
                conv_size1, conv_size2 = conv_size
 
                kernel_1 = int(subsampling_factor / 2)
 
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((kernel_1, 2)),
                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((2, 2)),
                )
 
                output_proj = conv_size2 * ((input_size // 2) // 2)
 
                self.subsampling_factor = subsampling_factor
 
                self.create_new_mask = self.create_new_vgg_mask
                
                self.stride_1 = kernel_1
 
        else:
            if subsampling_factor == 1:
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
                    torch.nn.ReLU(),
                )
 
                output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
 
                self.subsampling_factor = subsampling_factor
                self.kernel_2 = 3
                self.stride_2 = 1
 
                self.create_new_mask = self.create_new_conv2d_mask
 
            else:
                kernel_2, stride_2, conv_2_output_size = sub_factor_to_params(
                    subsampling_factor,
                    input_size,
                )
 
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, 2),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
                    torch.nn.ReLU(),
                )
 
                output_proj = conv_size * conv_2_output_size
 
                self.subsampling_factor = subsampling_factor
                self.kernel_2 = kernel_2
                self.stride_2 = stride_2
 
                self.create_new_mask = self.create_new_conv2d_mask
 
        self.vgg_like = vgg_like
        self.min_frame_length = 7
 
        if output_size is not None:
            self.output = torch.nn.Linear(output_proj, output_size)
            self.output_size = output_size
        else:
            self.output = None
            self.output_size = output_proj
 
    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequences.
 
        Args:
            x: ConvInput input sequences. (B, T, D_feats)
            mask: Mask of input sequences. (B, 1, T)
 
        Returns:
            x: ConvInput output sequences. (B, sub(T), D_out)
            mask: Mask of output sequences. (B, 1, sub(T))
 
        """
        if mask is not None:
            mask = self.create_new_mask(mask)
            olens = max(mask.eq(0).sum(1))
 
        b, t, f = x.size()
        x = x.unsqueeze(1) # (b. 1. t. f)
 
        if chunk_size is not None:
            max_input_length = int(
                chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) ))
            )
            x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
            x = list(x)
            x = torch.stack(x, dim=0)
            N_chunks = max_input_length // ( chunk_size * self.subsampling_factor)
            x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
 
        x = self.conv(x)
 
        _, c, _, f = x.size()
        if chunk_size is not None:
            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:]
        else:
            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
 
        if self.output is not None:
            x = self.output(x)
 
        return x, mask[:,:olens][:,:x.size(1)]
 
    def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create a new mask for VGG output sequences.
 
        Args:
            mask: Mask of input sequences. (B, T)
 
        Returns:
            mask: Mask of output sequences. (B, sub(T))
 
        """
        if self.subsampling_factor > 1:
            vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 ))
            mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2]
 
            vgg2_t_len = mask.size(1) - (mask.size(1) % 2)
            mask = mask[:, :vgg2_t_len][:, ::2]
        else:
            mask = mask
 
        return mask
 
    def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create new conformer mask for Conv2d output sequences.
 
        Args:
            mask: Mask of input sequences. (B, T)
 
        Returns:
            mask: Mask of output sequences. (B, sub(T))
 
        """
        if self.subsampling_factor > 1:
            return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
        else:
            return mask
 
    def get_size_before_subsampling(self, size: int) -> int:
        """Return the original size before subsampling for a given size.
 
        Args:
            size: Number of frames after subsampling.
 
        Returns:
            : Number of frames before subsampling.
 
        """
        return size * self.subsampling_factor