1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""Write a ``configuration.json`` describing a FunASR PyTorch ASR model.

The file records the task name, the acoustic-model files, decoding options
and the inference pipeline type. It is written to ``--output_dir``.
"""
import argparse
import json
import os


def get_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Write configuration.json for a FunASR PyTorch ASR model."
    )
    parser.add_argument(
        "--task",
        type=str,
        default="auto-speech-recognition",
        help="task name",
    )
    parser.add_argument(
        "--type",
        type=str,
        default="generic-asr",
        help="model type (written to model.type)",
    )
    parser.add_argument(
        "--am_model_name",
        type=str,
        default="model.pb",
        help="model file name",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="paraformer",
        help="mode for decoding",
    )
    parser.add_argument(
        "--lang",
        type=str,
        default="zh-cn",
        help="language",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="batch size",
    )
    parser.add_argument(
        "--am_model_config",
        type=str,
        default="config.yaml",
        help="config file",
    )
    parser.add_argument(
        "--mvn_file",
        type=str,
        default="am.mvn",
        help="cmvn file",
    )
    # The next three values are interpolated into the model id; without them
    # the id would silently contain the literal text "None", so require them.
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="model name",
    )
    parser.add_argument(
        "--pipeline_type",
        type=str,
        default="asr-inference",
        help="pipeline type",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        required=True,
        help="vocabulary size (used in the model id)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="dataset name",
    )
    # Previously optional, but os.path.join(None, ...) crashed when omitted.
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="output path",
    )
    parser.add_argument(
        "--nat",
        type=str,
        default="",
        help="suffix inserted right after 'asr' in the model id (empty by default)",
    )
    parser.add_argument(
        "--tag",
        type=str,
        default="exp1",
        help="model name tag",
    )
    return parser


def build_config(args):
    """Return the configuration dict for parsed command-line ``args``.

    Key insertion order is kept stable so the dumped JSON layout is
    deterministic.
    """
    model_id = "speech_{}_asr{}-{}-16k-{}-vocab{}-pytorch-{}".format(
        args.model_name, args.nat, args.lang, args.dataset, args.vocab_size, args.tag
    )
    return {
        "framework": "pytorch",
        "task": args.task,
        "model": {
            "type": args.type,
            "am_model_name": args.am_model_name,
            "model_config": {
                "type": "pytorch",
                "code_base": "funasr",
                "mode": args.mode,
                "lang": args.lang,
                "batch_size": args.batch_size,
                "am_model_config": args.am_model_config,
                "mvn_file": args.mvn_file,
                "model": model_id,
            },
        },
        "pipeline": {"type": args.pipeline_type},
    }


def main(argv=None):
    """Parse ``argv`` (default: ``sys.argv[1:]``) and write configuration.json.

    Returns:
        The path of the written ``configuration.json``.

    Raises:
        SystemExit: if required arguments are missing or invalid (argparse).
    """
    args = get_parser().parse_args(argv)
    config = build_config(args)
    # Create the destination on demand rather than failing with
    # FileNotFoundError; an empty output_dir keeps writing into the cwd.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, "configuration.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
    return output_path


if __name__ == '__main__':
    main()