游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import logging
 
import torch
 
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.base_model import FunASRModel
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.sv_decoder import DenseDecoder
from funasr.models.e2e_sv import ESPnetSVModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.pooling.statistic_pooling import StatisticPooling
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
 
# Registries mapping a configuration name to the class implementing it.
# Each registry is consumed by ``build_sv_model`` through ``get_class``.

# Feature-extraction frontends.
frontend_choices = ClassChoices(
    name="frontend",
    classes={
        "default": DefaultFrontend,
        "sliding_window": SlidingWindow,
        "s3prl": S3prlFrontend,
        "fused": FusedFrontends,
        "wav_frontend": WavFrontend,
    },
    type_check=AbsFrontend,
    default="default",
)

# Spectrogram augmentation (registered as optional, no default selection).
specaug_choices = ClassChoices(
    name="specaug",
    classes={"specaug": SpecAug},
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)

# Feature normalization (registered as optional, no default selection).
normalize_choices = ClassChoices(
    "normalize",
    classes={
        "global_mvn": GlobalMVN,
        "utterance_mvn": UtteranceMVN,
    },
    type_check=AbsNormalize,
    default=None,
    optional=True,
)

# Top-level model wrapper.
model_choices = ClassChoices(
    "model",
    classes={"espnet": ESPnetSVModel},
    type_check=FunASRModel,
    default="espnet",
)

# Pre-encoder input block (registered as optional, no default selection).
preencoder_choices = ClassChoices(
    name="preencoder",
    classes={
        "sinc": LightweightSincConvs,
        "linear": LinearProjection,
    },
    type_check=AbsPreEncoder,
    default=None,
    optional=True,
)

# Encoders.
encoder_choices = ClassChoices(
    "encoder",
    classes={
        "resnet34": ResNet34,
        "resnet34_sp_l2reg": ResNet34_SP_L2Reg,
        "rnn": RNNEncoder,
    },
    type_check=AbsEncoder,
    default="resnet34",
)

# Post-encoder block (registered as optional, no default selection).
postencoder_choices = ClassChoices(
    name="postencoder",
    classes={"hugging_face_transformers": HuggingFaceTransformersPostEncoder},
    type_check=AbsPostEncoder,
    default=None,
    optional=True,
)

# Pooling layers; note the registry name is "pooling_type", not "pooling".
pooling_choices = ClassChoices(
    name="pooling_type",
    classes={"statistic": StatisticPooling},
    type_check=torch.nn.Module,
    default="statistic",
)

# Decoders.
decoder_choices = ClassChoices(
    "decoder",
    classes={"dense": DenseDecoder},
    type_check=AbsDecoder,
    default="dense",
)
 
# All registries of this task, in registration order.  The trailing comment
# on each entry names the pair of options associated with that registry.
class_choices_list = [
    frontend_choices,  # --frontend and --frontend_conf
    specaug_choices,  # --specaug and --specaug_conf
    normalize_choices,  # --normalize and --normalize_conf
    model_choices,  # --model and --model_conf
    preencoder_choices,  # --preencoder and --preencoder_conf
    encoder_choices,  # --encoder and --encoder_conf
    postencoder_choices,  # --postencoder and --postencoder_conf
    pooling_choices,  # --pooling_type and --pooling_type_conf
    decoder_choices,  # --decoder and --decoder_conf
]
 
 
def build_sv_model(args):
    """Build the model described by ``args`` from the registered components.

    The components are created in this order: frontend, specaug, normalize,
    pre-encoder, encoder, post-encoder, pooling layer, decoder, and finally
    the model that wraps all of them.

    Args:
        args: Namespace-like object. The attributes read are ``token_list``,
            ``input_size``, ``frontend``/``frontend_conf``,
            ``specaug``/``specaug_conf``, ``normalize``/``normalize_conf``,
            ``preencoder``/``preencoder_conf`` (``preencoder`` may be absent),
            ``encoder``/``encoder_conf``, ``postencoder``/``postencoder_conf``
            (``postencoder`` may be absent), ``pooling_type``,
            ``pooling_type_conf`` (may be absent), ``decoder``/``decoder_conf``,
            ``model`` (falls back to ``"espnet"`` on ``AttributeError``),
            ``model_conf`` and ``init``.

    Returns:
        The constructed model instance.

    Raises:
        RuntimeError: If ``args.token_list`` is not a str, tuple or list.

    Note:
        ``args`` is modified in place: a ``token_list`` given as a file path
        is replaced by the list read from that file, and ``frontend`` /
        ``frontend_conf`` are reset when ``input_size`` is provided.
    """
    # 0. Token list: either a path to a text file (one entry per line) or an
    #    already materialized sequence.
    token_source = args.token_list
    if isinstance(token_source, str):
        token_list = []
        with open(token_source, encoding="utf-8") as token_file:
            for entry in token_file:
                token_list.append(entry.rstrip())
        # Store the contents instead of the path to keep ``args`` "portable".
        args.token_list = list(token_list)
    elif isinstance(token_source, (tuple, list)):
        token_list = list(token_source)
    else:
        raise RuntimeError("token_list must be str or list")
    vocab_size = len(token_list)
    logging.info(f"Speaker number: {vocab_size}")

    # 1. Frontend
    if args.input_size is not None:
        # Features are provided by the data-loader, so no frontend is built.
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size
    else:
        # Features are extracted inside the model.
        frontend = frontend_choices.get_class(args.frontend)(**args.frontend_conf)
        input_size = frontend.output_size()

    # 2. Data augmentation for spectrogram
    specaug = None
    if args.specaug is not None:
        specaug = specaug_choices.get_class(args.specaug)(**args.specaug_conf)

    # 3. Normalization layer
    normalize = None
    if args.normalize is not None:
        normalize = normalize_choices.get_class(args.normalize)(
            **args.normalize_conf
        )

    # 4. Pre-encoder input block
    # NOTE(kan-bayashi): Use getattr to keep the compatibility
    preencoder = None
    preencoder_name = getattr(args, "preencoder", None)
    if preencoder_name is not None:
        preencoder = preencoder_choices.get_class(preencoder_name)(
            **args.preencoder_conf
        )
        # The encoder consumes the pre-encoder output from here on.
        input_size = preencoder.output_size()

    # 5. Encoder
    encoder = encoder_choices.get_class(args.encoder)(
        input_size=input_size, **args.encoder_conf
    )
    encoder_output_size = encoder.output_size()

    # 6. Post-encoder block
    # NOTE(kan-bayashi): Use getattr to keep the compatibility
    postencoder = None
    postencoder_name = getattr(args, "postencoder", None)
    if postencoder_name is not None:
        postencoder = postencoder_choices.get_class(postencoder_name)(
            input_size=encoder_output_size, **args.postencoder_conf
        )
        encoder_output_size = postencoder.output_size()

    # 7. Pooling layer
    pooling_class = pooling_choices.get_class(args.pooling_type)
    # Defaults; only these two keys are taken from ``pooling_type_conf``.
    pooling_kwargs = {"pooling_dim": (2, 3), "eps": 1e-12}
    if hasattr(args, "pooling_type_conf"):
        for key in pooling_kwargs:
            if key in args.pooling_type_conf:
                pooling_kwargs[key] = args.pooling_type_conf[key]
    pooling_layer = pooling_class(**pooling_kwargs)
    if args.pooling_type == "statistic":
        # Statistic pooling doubles the feature size seen by the decoder
        # (presumably two statistics are concatenated -- confirm against
        # StatisticPooling).
        encoder_output_size *= 2

    # 8. Decoder
    decoder = decoder_choices.get_class(args.decoder)(
        vocab_size=vocab_size,
        encoder_output_size=encoder_output_size,
        **args.decoder_conf,
    )

    # 9. Build model
    try:
        model_class = model_choices.get_class(args.model)
    except AttributeError:
        # No usable ``model`` entry: fall back to the "espnet" model class.
        model_class = model_choices.get_class("espnet")
    components = {
        "vocab_size": vocab_size,
        "token_list": token_list,
        "frontend": frontend,
        "specaug": specaug,
        "normalize": normalize,
        "preencoder": preencoder,
        "encoder": encoder,
        "postencoder": postencoder,
        "pooling_layer": pooling_layer,
        "decoder": decoder,
    }
    model = model_class(**components, **args.model_conf)

    # FIXME(kamo): Should be done in model?
    # 10. Initialize
    if args.init is not None:
        initialize(model, args.init)

    return model