游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.data2vec import Data2VecPretrainModel
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.specaug.specaug import SpecAug
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
 
# Registries that map a configuration name to the class implementing it.
# ``build_pretrain_model`` resolves the requested name through ``get_class``.

# Front-end implementations; the registry default is "default".
frontend_choices = ClassChoices(
    name="frontend",
    classes={
        "default": DefaultFrontend,
        "sliding_window": SlidingWindow,
        "wav_frontend": WavFrontend,
    },
    default="default",
)

# Spectrogram augmentation; optional, with no default selection.
specaug_choices = ClassChoices(
    name="specaug",
    classes={"specaug": SpecAug},
    default=None,
    optional=True,
)

# Feature normalization layers; optional, with no default selection.
normalize_choices = ClassChoices(
    name="normalize",
    classes={
        "global_mvn": GlobalMVN,
        "utterance_mvn": UtteranceMVN,
    },
    default=None,
    optional=True,
)

# Encoders; only the data2vec encoder is registered.
encoder_choices = ClassChoices(
    name="encoder",
    classes={"data2vec_encoder": Data2VecEncoder},
    default="data2vec_encoder",
)

# Top-level models; only the data2vec pre-training model is registered.
model_choices = ClassChoices(
    name="model",
    classes={"data2vec": Data2VecPretrainModel},
    default="data2vec",
)

# Every registry above, in one list. Each entry is annotated with the pair of
# options it corresponds to.
class_choices_list = [
    frontend_choices,  # --frontend and --frontend_conf
    specaug_choices,  # --specaug and --specaug_conf
    normalize_choices,  # --normalize and --normalize_conf
    encoder_choices,  # --encoder and --encoder_conf
    model_choices,  # --model and --model_conf
]
 
 
def build_pretrain_model(args):
    """Assemble the pre-training model described by ``args``.

    Components are created in this order: frontend, specaug, normalize,
    encoder, and finally the model that wraps them.

    Args:
        args: Configuration object. The attributes read here are
            ``input_size``, ``frontend``/``frontend_conf``,
            ``specaug``/``specaug_conf``, ``normalize``/``normalize_conf``,
            ``encoder``/``encoder_conf``, ``model`` and ``init``.
            NOTE: when ``args.input_size`` is not None, ``args.frontend`` is
            overwritten with None and ``args.frontend_conf`` with ``{}``.

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: If ``args.model`` is anything other than
            ``"data2vec"``. This is raised only after the other components
            have already been built.
    """
    # Frontend: skipped entirely when the caller supplies the input size.
    if args.input_size is not None:
        # Clear the frontend settings on ``args`` so they match the model
        # that is actually built (no frontend).
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size
    else:
        frontend = frontend_choices.get_class(args.frontend)(**args.frontend_conf)
        input_size = frontend.output_size()

    # Spectrogram augmentation (optional).
    specaug = None
    if args.specaug is not None:
        specaug = specaug_choices.get_class(args.specaug)(**args.specaug_conf)

    # Normalization layer (optional).
    normalize = None
    if args.normalize is not None:
        normalize = normalize_choices.get_class(args.normalize)(**args.normalize_conf)

    # Encoder: takes the feature dimension determined above.
    encoder = encoder_choices.get_class(args.encoder)(
        input_size=input_size,
        **args.encoder_conf,
    )

    # Only the data2vec model is wired up here.
    if args.model != "data2vec":
        raise NotImplementedError("Not supported model: {}".format(args.model))

    model = model_choices.get_class("data2vec")(
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        encoder=encoder,
    )

    # Run the initializer only when an init scheme was requested.
    if args.init is not None:
        initialize(model, args.init)

    return model