游雁
2023-11-09 adf32376629f6940c84b62167bee6c252e6c2fcc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import argparse
import json
import os
 
def build_parser():
    """Create the command-line parser for the configuration generator.

    Returns:
        argparse.ArgumentParser: Parser exposing every option that feeds the
        generated ``configuration.json``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task",
        type=str,
        default="auto-speech-recognition",
        help="task name",
    )
    parser.add_argument(
        "--type",
        type=str,
        default="generic-asr",
        help="model type",
    )
    parser.add_argument(
        "--am_model_name",
        type=str,
        default="model.pb",
        help="model file name",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="paraformer",
        help="mode for decoding",
    )
    parser.add_argument(
        "--lang",
        type=str,
        default="zh-cn",
        help="language",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="batch size",
    )
    parser.add_argument(
        "--am_model_config",
        type=str,
        default="config.yaml",
        help="config file",
    )
    parser.add_argument(
        "--mvn_file",
        type=str,
        default="am.mvn",
        help="cmvn file",
    )
    # NOTE: --model_name, --vocab_size and --dataset stay optional so existing
    # invocations keep working; when omitted they are rendered as the text
    # "None" inside the generated model identifier.
    parser.add_argument(
        "--model_name",
        type=str,
        help="model name",
    )
    parser.add_argument(
        "--pipeline_type",
        type=str,
        default="asr-inference",
        help="pipeline type",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        help="vocab_size",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        help="dataset name",
    )
    # Required: without a value the script previously died with a TypeError
    # inside os.path.join(None, ...); argparse now reports a usage error.
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="output path",
    )
    parser.add_argument(
        "--nat",
        type=str,
        default="",
        help="nat",
    )
    parser.add_argument(
        "--tag",
        type=str,
        default="exp1",
        help="model name tag",
    )
    return parser


def build_configuration(args):
    """Assemble the configuration dictionary from parsed arguments.

    Args:
        args: Namespace produced by ``build_parser().parse_args(...)``.

    Returns:
        dict: Mapping with the keys ``framework``, ``task``, ``model`` and
        ``pipeline``, in that order, ready to be serialized as JSON.
    """
    # The model identifier is composed from the name, nat marker, language,
    # dataset, vocabulary size and tag around fixed literal segments.
    model_id = "speech_{}_asr{}-{}-16k-{}-vocab{}-pytorch-{}".format(
        args.model_name,
        args.nat,
        args.lang,
        args.dataset,
        args.vocab_size,
        args.tag,
    )
    model = {
        "type": args.type,
        "am_model_name": args.am_model_name,
        "model_config": {
            "type": "pytorch",
            "code_base": "funasr",
            "mode": args.mode,
            "lang": args.lang,
            "batch_size": args.batch_size,
            "am_model_config": args.am_model_config,
            "mvn_file": args.mvn_file,
            "model": model_id,
        },
    }
    return {
        "framework": "pytorch",
        "task": args.task,
        "model": model,
        "pipeline": {"type": args.pipeline_type},
    }


def main(argv=None):
    """Parse arguments and write ``configuration.json`` into the output dir.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse read
            the process command line.
    """
    args = build_parser().parse_args(argv)

    # Create the target directory on demand instead of failing with
    # FileNotFoundError. An empty string still means the current directory,
    # which os.makedirs would reject, hence the guard.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    config_path = os.path.join(args.output_dir, "configuration.json")
    # Explicit encoding keeps the output independent of the locale default.
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(build_configuration(args), f, indent=4)


if __name__ == "__main__":
    main()