huangmingming
2023-01-30 adcee8828ef5d78b575043954deb662a35e318f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
 
import yaml
 
 
def update_dct(fin_configs, root):
    if root == {}:
        return {}
    for root_key, root_value in root.items():
        if not isinstance(root[root_key], dict):
            fin_configs[root_key] = root[root_key]
        else:
            if root_key in fin_configs.keys():
                result = update_dct(fin_configs[root_key], root[root_key])
                fin_configs[root_key] = result
            else:
                fin_configs[root_key] = root[root_key]
    return fin_configs
 
 
def parse_args(mode):
    if mode == "asr":
        from funasr.tasks.asr import ASRTask as ASRTask
    elif mode == "paraformer":
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    elif mode == "paraformer_vad_punc":
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    elif mode == "uniasr":
        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
    else:
        raise ValueError("Unknown mode: {}".format(mode))
    parser = ASRTask.get_parser()
    args = parser.parse_args()
    return args, ASRTask
 
 
def build_trainer(modelscope_dict, data_dir, output_dir, train_set="train", dev_set="validation", distributed=False,
                  dataset_type="small", lr=None, batch_bins=None, max_epoch=None, mate_params=None):
    mode = modelscope_dict['mode']
    args, ASRTask = parse_args(mode=mode)
    # ddp related
    if args.local_rank is not None:
        distributed = True
    else:
        distributed = False
    args.local_rank = args.local_rank if args.local_rank is not None else 0
    local_rank = args.local_rank
    if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
 
    config = modelscope_dict['am_model_config']
    finetune_config = modelscope_dict['finetune_config']
    init_param = modelscope_dict['init_model']
    cmvn_file = modelscope_dict['cmvn_file']
    seg_dict_file = modelscope_dict['seg_dict']
 
    # overwrite parameters
    with open(config) as f:
        configs = yaml.safe_load(f)
    with open(finetune_config) as f:
        finetune_configs = yaml.safe_load(f)
        # set data_types
        if dataset_type == "large":
            finetune_configs["dataset_conf"]["data_types"] = "sound,text"
    finetune_configs = update_dct(configs, finetune_configs)
    for key, value in finetune_configs.items():
        if hasattr(args, key):
            setattr(args, key, value)
 
    # prepare data
    args.dataset_type = dataset_type
    if args.dataset_type == "small":
        args.train_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, train_set), "speech", "sound"],
                                                  ["{}/{}/text".format(data_dir, train_set), "text", "text"]]
        args.valid_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, dev_set), "speech", "sound"],
                                                  ["{}/{}/text".format(data_dir, dev_set), "text", "text"]]
    elif args.dataset_type == "large":
        args.train_data_file = None
        args.valid_data_file = None
    else:
        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
    args.init_param = [init_param]
    args.cmvn_file = cmvn_file
    if os.path.exists(seg_dict_file):
        args.seg_dict_file = seg_dict_file
    else:
        args.seg_dict_file = None
    args.data_dir = data_dir
    args.train_set = train_set
    args.dev_set = dev_set
    args.output_dir = output_dir
    args.gpu_id = args.local_rank
    args.config = finetune_config
    if lr is not None:
        args.optim_conf["lr"] = lr
    if max_epoch is not None:
        args.max_epoch = max_epoch
    if batch_bins is not None:
        if args.dataset_type == "small":
            args.batch_bins = batch_bins
        elif args.dataset_type == "large":
            args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
        else:
            raise ValueError(f"Not supported dataset_type={args.dataset_type}")
    if args.normalize in ["null", "none", "None"]:
        args.normalize = None
    if args.patience in ["null", "none", "None"]:
        args.patience = None
    args.local_rank = local_rank
    args.distributed = distributed
    ASRTask.finetune_args = args
 
    return ASRTask