Yabin Li
2023-05-08 8a08405b668e06c4670b4c13f6793e193f21a21d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
 
import yaml
 
 
def update_dct(fin_configs, root):
    if root == {}:
        return {}
    for root_key, root_value in root.items():
        if not isinstance(root[root_key], dict):
            fin_configs[root_key] = root[root_key]
        else:
            if root_key in fin_configs.keys():
                result = update_dct(fin_configs[root_key], root[root_key])
                fin_configs[root_key] = result
            else:
                fin_configs[root_key] = root[root_key]
    return fin_configs
 
 
def parse_args(mode):
    if mode == "asr":
        from funasr.tasks.asr import ASRTask as ASRTask
    elif mode == "paraformer":
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    elif mode == "paraformer_vad_punc":
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    elif mode == "uniasr":
        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
    elif mode == "mfcca":
        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
    elif mode == "tp":
        from funasr.tasks.asr import ASRTaskAligner as ASRTask
    else:
        raise ValueError("Unknown mode: {}".format(mode))
    parser = ASRTask.get_parser()
    args = parser.parse_args()
    return args, ASRTask
 
 
def build_trainer(modelscope_dict,
                  data_dir,
                  output_dir,
                  train_set="train",
                  dev_set="validation",
                  distributed=False,
                  dataset_type="small",
                  batch_bins=None,
                  max_epoch=None,
                  optim=None,
                  lr=None,
                  scheduler=None,
                  scheduler_conf=None,
                  specaug=None,
                  specaug_conf=None,
                  param_dict=None,
                  **kwargs):
    mode = modelscope_dict['mode']
    args, ASRTask = parse_args(mode=mode)
    # ddp related
    if args.local_rank is not None:
        distributed = True
    else:
        distributed = False
    args.local_rank = args.local_rank if args.local_rank is not None else 0
    local_rank = args.local_rank
    if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
 
    config = modelscope_dict['am_model_config']
    finetune_config = modelscope_dict['finetune_config']
    init_param = modelscope_dict['init_model']
    cmvn_file = modelscope_dict['cmvn_file']
    seg_dict_file = modelscope_dict['seg_dict']
 
    # overwrite parameters
    with open(config) as f:
        configs = yaml.safe_load(f)
    with open(finetune_config) as f:
        finetune_configs = yaml.safe_load(f)
        # set data_types
        if dataset_type == "large":
            if 'data_types' not in finetune_configs['dataset_conf']:
                finetune_configs["dataset_conf"]["data_types"] = "sound,text"
    finetune_configs = update_dct(configs, finetune_configs)
    for key, value in finetune_configs.items():
        if hasattr(args, key):
            setattr(args, key, value)
 
    # prepare data
    args.dataset_type = dataset_type
    if args.dataset_type == "small":
        args.train_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, train_set), "speech", "sound"],
                                                  ["{}/{}/text".format(data_dir, train_set), "text", "text"]]
        args.valid_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, dev_set), "speech", "sound"],
                                                  ["{}/{}/text".format(data_dir, dev_set), "text", "text"]]
    elif args.dataset_type == "large":
        args.train_data_file = None
        args.valid_data_file = None
    else:
        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
    args.init_param = [init_param]
    args.cmvn_file = cmvn_file
    if os.path.exists(seg_dict_file):
        args.seg_dict_file = seg_dict_file
    else:
        args.seg_dict_file = None
    args.data_dir = data_dir
    args.train_set = train_set
    args.dev_set = dev_set
    args.output_dir = output_dir
    args.gpu_id = args.local_rank
    args.config = finetune_config
    if optim is not None:
        args.optim = optim
    if lr is not None:
        args.optim_conf["lr"] = lr
    if scheduler is not None:
        args.scheduler = scheduler
    if scheduler_conf is not None:
        args.scheduler_conf = scheduler_conf
    if specaug is not None:
        args.specaug = specaug
    if specaug_conf is not None:
        args.specaug_conf = specaug_conf
    if max_epoch is not None:
        args.max_epoch = max_epoch
    if batch_bins is not None:
        if args.dataset_type == "small":
            args.batch_bins = batch_bins
        elif args.dataset_type == "large":
            args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
        else:
            raise ValueError(f"Not supported dataset_type={args.dataset_type}")
    if args.normalize in ["null", "none", "None"]:
        args.normalize = None
    if args.patience in ["null", "none", "None"]:
        args.patience = None
    args.local_rank = local_rank
    args.distributed = distributed
    ASRTask.finetune_args = args
 
    return ASRTask