游雁
2023-11-20 df03a020f6d8fe4e9b09c1e784fead2852d90bfc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from funasr.models.ctc import CTC
from funasr.utils import config_argparse
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
 
 
def _add_input_size(task_parser):
    """Register the ``--input_size`` option on ``task_parser``."""
    task_parser.add_argument(
        "--input_size",
        type=int_or_none,
        default=None,
        help="The number of input dimension of the feature",
    )


def _add_cmvn_file(task_parser):
    """Register the ``--cmvn_file`` option on ``task_parser``."""
    task_parser.add_argument(
        "--cmvn_file",
        type=str_or_none,
        default=None,
        help="The path of cmvn file.",
    )


def _add_asr_options(task_parser):
    """Register every option that is specific to the ``asr`` task."""
    task_parser.add_argument(
        "--split_with_space",
        type=str2bool,
        default=True,
        help="whether to split text using <space>",
    )
    task_parser.add_argument(
        "--seg_dict_file",
        type=str,
        default=None,
        help="seg_dict_file for text processing",
    )
    _add_input_size(task_parser)
    task_parser.add_argument(
        "--ctc_conf",
        action=NestedDictAction,
        default=get_default_kwargs(CTC),
        help="The keyword arguments for CTC class.",
    )
    _add_cmvn_file(task_parser)


# task name -> (module exposing ``class_choices_list``, extra option registrars).
# The module is named by its dotted path so that it is only imported once the
# corresponding task has actually been selected.
_TASK_SPECS = {
    "asr": ("funasr.build_utils.build_asr_model", (_add_asr_options,)),
    "pretrain": (
        "funasr.build_utils.build_pretrain_model",
        (_add_input_size, _add_cmvn_file),
    ),
    "lm": ("funasr.build_utils.build_lm_model", ()),
    "punc": ("funasr.build_utils.build_punc_model", ()),
    "vad": (
        "funasr.build_utils.build_vad_model",
        (_add_input_size, _add_cmvn_file),
    ),
    "diar": ("funasr.build_utils.build_diar_model", (_add_input_size,)),
    "sv": ("funasr.build_utils.build_sv_model", (_add_input_size,)),
}


def build_args(args, parser, extra_task_params):
    """Parse the task-specific command-line options for ``args.task_name``.

    A fresh parser is populated with the options contributed by the selected
    task's ``class_choices_list`` plus a few task-dependent extras, then
    extended with every action of ``parser`` whose ``dest`` is not already
    taken. The values already present in ``args`` are installed as defaults
    before ``extra_task_params`` is parsed.

    Args:
        args: Namespace-like object with a ``task_name`` attribute; all of its
            attributes (``vars(args)``) become defaults of the task parser.
        parser: The parser whose actions are merged into the task parser.
        extra_task_params: Argument strings handed to ``parse_args``
            (presumably the arguments ``parser`` did not recognise; confirm
            against the caller).

    Returns:
        The namespace produced by parsing ``extra_task_params``.

    Raises:
        NotImplementedError: If ``args.task_name`` is not a supported task.
    """
    import importlib  # local import keeps this block free of new file-level deps

    task_name = args.task_name
    # The isinstance guard keeps unhashable task names on the
    # NotImplementedError path instead of failing inside the dict lookup.
    spec = _TASK_SPECS.get(task_name) if isinstance(task_name, str) else None
    if spec is None:
        raise NotImplementedError("Not supported task: {}".format(task_name))
    module_path, extra_registrars = spec

    task_parser = config_argparse.ArgumentParser("Task related config")
    model_module = importlib.import_module(module_path)
    for class_choices in model_module.class_choices_list:
        class_choices.add_arguments(task_parser)
    for register in extra_registrars:
        register(task_parser)

    # Merge in the generic options, letting task-specific definitions win.
    # The set is updated after each addition so that a second action of
    # ``parser`` sharing a ``dest`` with an earlier one is skipped as well.
    # NOTE: ``_actions`` / ``_add_action`` are private argparse APIs.
    seen_dests = {action.dest for action in task_parser._actions}
    for action in parser._actions:
        if action.dest not in seen_dests:
            task_parser._add_action(action)
            seen_dests.add(action.dest)

    task_parser.set_defaults(**vars(args))
    return task_parser.parse_args(extra_task_params)