aky15
2023-04-12 7d1efe158eda74dc847c397db906f6cb77ac0f84
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""Set of methods to validate encoder architecture."""
 
from typing import Any, Dict, List, Tuple
 
from funasr.modules.nets_utils import sub_factor_to_params
 
 
def validate_block_arguments(
    configuration: Dict[str, Any],
    block_id: int,
    previous_block_output: int,
) -> Tuple[int, int]:
    """Validate block arguments.
 
    Args:
        configuration: Architecture configuration.
        block_id: Block ID.
        previous_block_output: Previous block output size.
 
    Returns:
        input_size: Block input size.
        output_size: Block output size.
 
    """
    block_type = configuration.get("block_type")
 
    if block_type is None:
        raise ValueError(
            "Block %d in encoder doesn't have a type assigned. " % block_id
        )
 
    if block_type in ["branchformer", "conformer"]:
        if configuration.get("linear_size") is None:
            raise ValueError(
                "Missing 'linear_size' argument for X-former block (ID: %d)" % block_id
            )
 
        if configuration.get("conv_mod_kernel_size") is None:
            raise ValueError(
                "Missing 'conv_mod_kernel_size' argument for X-former block (ID: %d)"
                % block_id
            )
 
        input_size = configuration.get("hidden_size")
        output_size = configuration.get("hidden_size")
 
    elif block_type == "conv1d":
        output_size = configuration.get("output_size")
 
        if output_size is None:
            raise ValueError(
                "Missing 'output_size' argument for Conv1d block (ID: %d)" % block_id
            )
 
        if configuration.get("kernel_size") is None:
            raise ValueError(
                "Missing 'kernel_size' argument for Conv1d block (ID: %d)" % block_id
            )
 
        input_size = configuration["input_size"] = previous_block_output
    else:
        raise ValueError("Block type: %s is not supported." % block_type)
 
    return input_size, output_size
 
 
def validate_input_block(
    configuration: Dict[str, Any], body_first_conf: Dict[str, Any], input_size: int
) -> int:
    """Validate input block.
 
    Args:
        configuration: Encoder input block configuration.
        body_first_conf: Encoder first body block configuration.
        input_size: Encoder input block input size.
 
    Return:
        output_size: Encoder input block output size.
 
    """
    vgg_like = configuration.get("vgg_like", False)
    linear = configuration.get("linear", False)
    next_block_type = body_first_conf.get("block_type")
    allowed_next_block_type = ["branchformer", "conformer", "conv1d"]
 
    if next_block_type is None or (next_block_type not in allowed_next_block_type):
        return -1
 
    if configuration.get("subsampling_factor") is None:
        configuration["subsampling_factor"] = 4
 
    if vgg_like:
        conv_size = configuration.get("conv_size", (64, 128))
 
        if isinstance(conv_size, int):
            conv_size = (conv_size, conv_size)
    else:
        conv_size = configuration.get("conv_size", None)
 
        if isinstance(conv_size, tuple):
            conv_size = conv_size[0]
 
    if next_block_type == "conv1d":
        if vgg_like:
            output_size = conv_size[1] * ((input_size // 2) // 2)
        else:
            if conv_size is None:
                conv_size = body_first_conf.get("output_size", 64)
 
            sub_factor = configuration["subsampling_factor"]
 
            _, _, conv_osize = sub_factor_to_params(sub_factor, input_size)
            assert (
                conv_osize > 0
            ), "Conv2D output size is <1 with input size %d and subsampling %d" % (
                input_size,
                sub_factor,
            )
 
            output_size = conv_osize * conv_size
 
        configuration["output_size"] = None
    else:
        output_size = body_first_conf.get("hidden_size")
 
        if conv_size is None:
            conv_size = output_size
 
        configuration["output_size"] = output_size
 
    configuration["conv_size"] = conv_size
    configuration["vgg_like"] = vgg_like
    configuration["linear"] = linear
 
    return output_size
 
 
def validate_architecture(
    input_conf: Dict[str, Any], body_conf: List[Dict[str, Any]], input_size: int
) -> Tuple[int, int]:
    """Validate specified architecture is valid.
 
    Args:
        input_conf: Encoder input block configuration.
        body_conf: Encoder body blocks configuration.
        input_size: Encoder input size.
 
    Returns:
        input_block_osize: Encoder input block output size.
        : Encoder body block output size.
 
    """
    input_block_osize = validate_input_block(input_conf, body_conf[0], input_size)
 
    cmp_io = []
 
    for i, b in enumerate(body_conf):
        _io = validate_block_arguments(
            b, (i + 1), input_block_osize if i == 0 else cmp_io[i - 1][1]
        )
 
        cmp_io.append(_io)
 
    for i in range(1, len(cmp_io)):
        if cmp_io[(i - 1)][1] != cmp_io[i][0]:
            raise ValueError(
                "Output/Input mismatch between blocks %d and %d"
                " in the encoder body." % ((i - 1), i)
            )
 
    return input_block_osize, cmp_io[-1][1]