aky15
2023-04-12 7d1efe158eda74dc847c397db906f6cb77ac0f84
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
"""Set of methods to build Transducer encoder architecture."""
 
from typing import Any, Callable, Dict, List, Optional, Union
 
from funasr.modules.activation import get_activation
from funasr.models.encoder.chunk_encoder_blocks.branchformer import Branchformer
from funasr.models.encoder.chunk_encoder_blocks.conformer import Conformer
from funasr.models.encoder.chunk_encoder_blocks.conv1d import Conv1d
from funasr.models.encoder.chunk_encoder_blocks.conv_input import ConvInput
from funasr.models.encoder.chunk_encoder_blocks.linear_input import LinearInput
from funasr.models.encoder.chunk_encoder_modules.attention import (  # noqa: H301
    RelPositionMultiHeadedAttention,
)
from funasr.models.encoder.chunk_encoder_modules.convolution import (  # noqa: H301
    ConformerConvolution,
    ConvolutionalSpatialGatingUnit,
)
from funasr.models.encoder.chunk_encoder_modules.multi_blocks import MultiBlocks
from funasr.models.encoder.chunk_encoder_modules.normalization import get_normalization
from funasr.models.encoder.chunk_encoder_modules.positional_encoding import (  # noqa: H301
    RelPositionalEncoding,
)
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,
)
 
 
def build_main_parameters(
    pos_wise_act_type: str = "swish",
    conv_mod_act_type: str = "swish",
    pos_enc_dropout_rate: float = 0.0,
    pos_enc_max_len: int = 5000,
    simplified_att_score: bool = False,
    norm_type: str = "layer_norm",
    conv_mod_norm_type: str = "layer_norm",
    after_norm_eps: Optional[float] = None,
    after_norm_partial: Optional[float] = None,
    dynamic_chunk_training: bool = False,
    short_chunk_threshold: float = 0.75,
    short_chunk_size: int = 25,
    left_chunk_size: int = 0,
    time_reduction_factor: int = 1,
    unified_model_training: bool = False,
    default_chunk_size: int = 16,
    jitter_range: int = 4,
    **activation_parameters,
) -> Dict[str, Any]:
    """Build encoder main parameters.

    Args:
        pos_wise_act_type: Conformer position-wise feed-forward activation type.
        conv_mod_act_type: Conformer convolution module activation type.
        pos_enc_dropout_rate: Positional encoding dropout rate.
        pos_enc_max_len: Positional encoding maximum length.
        simplified_att_score: Whether to use simplified attention score computation.
        norm_type: X-former normalization module type. Also used to build the
            final (after-encoder) normalization.
        conv_mod_norm_type: Conformer convolution module normalization type.
        after_norm_eps: Epsilon value for the final normalization.
        after_norm_partial: Value for the final normalization with RMSNorm.
        dynamic_chunk_training: Whether to use dynamic chunk training.
        short_chunk_threshold: Threshold for dynamic chunk selection
            (clamped to >= 0).
        short_chunk_size: Minimum number of frames during dynamic chunk training
            (clamped to >= 0).
        left_chunk_size: Number of frames in left context (clamped to >= 0).
        time_reduction_factor: Time reduction factor, stored unchanged.
        unified_model_training: Whether to use unified model training.
        default_chunk_size: Default chunk size (clamped to >= 0).
        jitter_range: Chunk jitter range (clamped to >= 0).
        **activation_parameters: Extra keyword arguments forwarded to
            ``get_activation`` when building both activation modules.

    Returns:
        : Main encoder parameters. Keys: "pos_wise_act", "conv_mod_act",
          "pos_enc_dropout_rate", "pos_enc_max_len", "simplified_att_score",
          "norm_type", "conv_mod_norm_type", "after_norm_class",
          "after_norm_args", "dynamic_chunk_training", "short_chunk_threshold",
          "short_chunk_size", "left_chunk_size", "unified_model_training",
          "default_chunk_size", "jitter_range", "time_reduction_factor".

    """
    main_params = {}

    # Activation modules shared by every block built from these parameters.
    main_params["pos_wise_act"] = get_activation(
        pos_wise_act_type, **activation_parameters
    )

    main_params["conv_mod_act"] = get_activation(
        conv_mod_act_type, **activation_parameters
    )

    main_params["pos_enc_dropout_rate"] = pos_enc_dropout_rate
    main_params["pos_enc_max_len"] = pos_enc_max_len

    main_params["simplified_att_score"] = simplified_att_score

    main_params["norm_type"] = norm_type
    main_params["conv_mod_norm_type"] = conv_mod_norm_type

    # get_normalization returns a (class, kwargs) pair for the final norm layer.
    (
        main_params["after_norm_class"],
        main_params["after_norm_args"],
    ) = get_normalization(norm_type, eps=after_norm_eps, partial=after_norm_partial)

    # Chunk-related values are clamped so negative config values become 0.
    main_params["dynamic_chunk_training"] = dynamic_chunk_training
    main_params["short_chunk_threshold"] = max(0, short_chunk_threshold)
    main_params["short_chunk_size"] = max(0, short_chunk_size)
    main_params["left_chunk_size"] = max(0, left_chunk_size)

    main_params["unified_model_training"] = unified_model_training
    main_params["default_chunk_size"] = max(0, default_chunk_size)
    main_params["jitter_range"] = max(0, jitter_range)

    main_params["time_reduction_factor"] = time_reduction_factor

    return main_params
 
 
def build_positional_encoding(
    block_size: int, configuration: Dict[str, Any]
) -> RelPositionalEncoding:
    """Create the relative positional encoding module of the encoder.

    Args:
        block_size: Input/output size of the encoding.
        configuration: Parameter mapping. The optional keys
            "pos_enc_dropout_rate" (default 0.0) and "pos_enc_max_len"
            (default 5000) are read.

    Returns:
        : RelPositionalEncoding module.

    """
    dropout_rate = configuration.get("pos_enc_dropout_rate", 0.0)
    max_length = configuration.get("pos_enc_max_len", 5000)

    return RelPositionalEncoding(block_size, dropout_rate, max_len=max_length)
 
 
def build_input_block(
    input_size: int,
    configuration: Dict[str, Union[str, int]],
) -> Union[ConvInput, LinearInput]:
    """Build encoder input block.

    Args:
        input_size: Input size.
        configuration: Input block configuration. Keys read:
            "linear" (optional, default False): select LinearInput instead
                of ConvInput.
            "output_size", "subsampling_factor": used by both variants.
            "conv_size", "vgg_like": used only by ConvInput.

    Returns:
        : LinearInput module when "linear" is truthy, otherwise a ConvInput
          module.

    """
    # A missing "linear" key falls back to the convolutional input block
    # instead of raising KeyError.
    if configuration.get("linear", False):
        return LinearInput(
            input_size,
            configuration["output_size"],
            configuration["subsampling_factor"],
        )

    return ConvInput(
        input_size,
        configuration["conv_size"],
        configuration["subsampling_factor"],
        vgg_like=configuration["vgg_like"],
        output_size=configuration["output_size"],
    )
 
 
def build_branchformer_block(
    configuration: Dict[str, Any],
    main_params: Dict[str, Any],
) -> Callable[[], Branchformer]:
    """Build a factory for Branchformer blocks.

    Args:
        configuration: Branchformer block configuration.
            Required keys: "hidden_size", "linear_size", "conv_mod_kernel_size".
            Optional keys: "dropout_rate" (default 0.0), "heads" (default 4),
            "att_dropout_rate" (default 0.0), "conv_mod_norm_eps",
            "conv_mod_norm_partial", "norm_eps", "norm_partial".
        main_params: Encoder main parameters (see build_main_parameters).

    Returns:
        : Zero-argument function returning a new Branchformer block, with
          fresh attention and convolution sub-modules, on each call.

    """
    hidden_size = configuration["hidden_size"]
    linear_size = configuration["linear_size"]

    dropout_rate = configuration.get("dropout_rate", 0.0)

    conv_mod_norm_class, conv_mod_norm_args = get_normalization(
        main_params["conv_mod_norm_type"],
        eps=configuration.get("conv_mod_norm_eps"),
        partial=configuration.get("conv_mod_norm_partial"),
    )

    # Use the same convolution chunk/streaming flag as build_conformer_block,
    # so unified model training configures Branchformer convolutions the same
    # way as Conformer convolutions.
    # NOTE(review): this is the last positional argument of
    # ConvolutionalSpatialGatingUnit, presumably its `causal` flag.
    conv_chunk_mode = main_params["dynamic_chunk_training"] or main_params.get(
        "unified_model_training", False
    )

    conv_mod_args = (
        linear_size,
        configuration["conv_mod_kernel_size"],
        conv_mod_norm_class,
        conv_mod_norm_args,
        dropout_rate,
        conv_chunk_mode,
    )

    mult_att_args = (
        configuration.get("heads", 4),
        hidden_size,
        configuration.get("att_dropout_rate", 0.0),
        main_params["simplified_att_score"],
    )

    norm_class, norm_args = get_normalization(
        main_params["norm_type"],
        eps=configuration.get("norm_eps"),
        partial=configuration.get("norm_partial"),
    )

    # Sub-modules are created inside the factory so each block owns its own
    # parameters.
    return lambda: Branchformer(
        hidden_size,
        linear_size,
        RelPositionMultiHeadedAttention(*mult_att_args),
        ConvolutionalSpatialGatingUnit(*conv_mod_args),
        norm_class=norm_class,
        norm_args=norm_args,
        dropout_rate=dropout_rate,
    )
 
 
def build_conformer_block(
    configuration: Dict[str, Any],
    main_params: Dict[str, Any],
) -> Callable[[], Conformer]:
    """Build a factory for Conformer blocks.

    Args:
        configuration: Conformer block configuration.
            Required keys: "hidden_size", "linear_size", "conv_mod_kernel_size".
            Optional keys: "pos_wise_dropout_rate" (default 0.0),
            "conv_mod_norm_eps" (default 1e-05), "conv_mod_norm_momentum"
            (default 0.1), "heads" (default 4), "att_dropout_rate"
            (default 0.0), "norm_eps", "norm_partial", "dropout_rate"
            (default 0.0).
        main_params: Encoder main parameters (see build_main_parameters).

    Returns:
        : Zero-argument function returning a new Conformer block, with fresh
          attention, two feed-forward and convolution sub-modules, on each call.

    """
    hidden_size = configuration["hidden_size"]
    linear_size = configuration["linear_size"]

    pos_wise_args = (
        hidden_size,
        linear_size,
        configuration.get("pos_wise_dropout_rate", 0.0),
        main_params["pos_wise_act"],
    )

    # The convolution module takes raw eps/momentum kwargs rather than a norm
    # class; main_params["conv_mod_norm_type"] is not used here.
    # NOTE(review): these look like BatchNorm-style arguments — confirm
    # against ConformerConvolution.
    conv_mod_norm_args = {
        "eps": configuration.get("conv_mod_norm_eps", 1e-05),
        "momentum": configuration.get("conv_mod_norm_momentum", 0.1),
    }

    conv_mod_args = (
        hidden_size,
        configuration["conv_mod_kernel_size"],
        main_params["conv_mod_act"],
        conv_mod_norm_args,
        # Chunk/streaming flag is enabled for both dynamic chunk training and
        # unified model training.
        main_params["dynamic_chunk_training"] or main_params["unified_model_training"],
    )

    mult_att_args = (
        configuration.get("heads", 4),
        hidden_size,
        configuration.get("att_dropout_rate", 0.0),
        main_params["simplified_att_score"],
    )

    norm_class, norm_args = get_normalization(
        main_params["norm_type"],
        eps=configuration.get("norm_eps"),
        partial=configuration.get("norm_partial"),
    )

    # Sub-modules are created inside the factory so each block owns its own
    # parameters (including the two distinct feed-forward modules).
    return lambda: Conformer(
        hidden_size,
        RelPositionMultiHeadedAttention(*mult_att_args),
        PositionwiseFeedForward(*pos_wise_args),
        PositionwiseFeedForward(*pos_wise_args),
        ConformerConvolution(*conv_mod_args),
        norm_class=norm_class,
        norm_args=norm_args,
        dropout_rate=configuration.get("dropout_rate", 0.0),
    )
 
 
def build_conv1d_block(
    configuration: Dict[str, Any],
    causal: bool,
) -> Callable[[], Conv1d]:
    """Build a factory for Conv1d blocks.

    Args:
        configuration: Conv1d block configuration.
            Required keys: "input_size", "output_size", "kernel_size".
            Optional keys: "stride" (default 1), "dilation" (default 1),
            "groups" (default 1), "bias" (default True), "relu"
            (default True), "batch_norm" (default False), "dropout_rate"
            (default 0.0).
        causal: Value forwarded as Conv1d's ``causal`` argument.

    Returns:
        : Zero-argument function returning a new Conv1d block on each call.

    """
    return lambda: Conv1d(
        configuration["input_size"],
        configuration["output_size"],
        configuration["kernel_size"],
        stride=configuration.get("stride", 1),
        dilation=configuration.get("dilation", 1),
        groups=configuration.get("groups", 1),
        bias=configuration.get("bias", True),
        relu=configuration.get("relu", True),
        batch_norm=configuration.get("batch_norm", False),
        causal=causal,
        dropout_rate=configuration.get("dropout_rate", 0.0),
    )
 
 
def build_body_blocks(
    configuration: List[Dict[str, Any]],
    main_params: Dict[str, Any],
    output_size: int,
) -> MultiBlocks:
    """Build encoder body blocks.

    Args:
        configuration: Body blocks configuration. Each entry needs a
            "block_type" ("branchformer", "conformer" or "conv1d"). An entry
            with a "num_blocks" key is expanded into that many copies of the
            entry, with "num_blocks" removed.
        main_params: Encoder main parameters (see build_main_parameters).
        output_size: Architecture output size.

    Returns:
        MultiBlocks module encapsulating all encoder blocks.

    Raises:
        NotImplementedError: If an entry has an unsupported "block_type".

    """
    # Expand "num_blocks" shorthand. Each copy is a distinct dict so blocks do
    # not share one configuration object.
    extended_conf = []
    for block_conf in configuration:
        num_blocks = block_conf.get("num_blocks")
        if num_blocks is None:
            extended_conf.append(block_conf)
        else:
            extended_conf.extend(
                {key: value for key, value in block_conf.items() if key != "num_blocks"}
                for _ in range(num_blocks)
            )

    fn_modules = []
    for block_conf in extended_conf:
        block_type = block_conf["block_type"]

        if block_type == "branchformer":
            module = build_branchformer_block(block_conf, main_params)
        elif block_type == "conformer":
            module = build_conformer_block(block_conf, main_params)
        elif block_type == "conv1d":
            # Same chunk/streaming condition as the Conformer convolution
            # module in build_conformer_block (dynamic chunk OR unified).
            causal = main_params["dynamic_chunk_training"] or main_params.get(
                "unified_model_training", False
            )
            module = build_conv1d_block(block_conf, causal)
        else:
            raise NotImplementedError(
                f"Unsupported encoder block type: {block_type!r}. "
                "Supported types are 'branchformer', 'conformer' and 'conv1d'."
            )

        fn_modules.append(module)

    return MultiBlocks(
        [fn() for fn in fn_modules],
        output_size,
        norm_class=main_params["after_norm_class"],
        norm_args=main_params["after_norm_args"],
    )