1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
| #!/usr/bin/env python3
| # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
| # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
| import argparse
| import logging
| import os
| import sys
| from typing import Union, Dict, Any
|
| from funasr.utils import config_argparse
| from funasr.utils.cli_utils import get_commandline_args
| from funasr.utils.types import str2bool
| from funasr.utils.types import str2triple_str
| from funasr.utils.types import str_or_none
|
|
def get_parser():
    """Build the command-line parser for VAD decoding.

    The parser defines general runtime options (logging, output directory,
    GPU/job layout, seed, dtype, workers) followed by three argument groups:
    input data, model configuration and inference configuration.

    Returns:
        The configured ``config_argparse.ArgumentParser`` instance.
    """
    parser = config_argparse.ArgumentParser(
        description="VAD Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        # Upper-case the value first so that e.g. "info" matches the choices.
        type=lambda level: level.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=True)

    # Integer options describing the GPU/job layout.
    for flag, default, description in (
        ("--ngpu", 0, "The number of gpus. 0 indicates CPU mode"),
        ("--njob", 1, "The number of jobs for each gpu"),
    ):
        parser.add_argument(flag, type=int, default=default, help=description)

    parser.add_argument(
        "--gpuid_list", type=str, default="", help="The visible gpus"
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    data_group = parser.add_argument_group("Input data related")
    data_group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",  # may be given several times; values accumulate
    )
    data_group.add_argument("--key_file", type=str_or_none)
    data_group.add_argument(
        "--allow_variable_data_keys", type=str2bool, default=False
    )

    # All model-related options are plain string paths with no default.
    model_group = parser.add_argument_group("The model configuration related")
    for flag, description in (
        ("--vad_infer_config", "VAD infer configuration"),
        ("--vad_model_file", "VAD model parameter file"),
        ("--vad_cmvn_file", "Global CMVN file"),
        ("--vad_train_config", "VAD training configuration"),
    ):
        model_group.add_argument(flag, type=str, help=description)

    infer_group = parser.add_argument_group(
        "The inference configuration related"
    )
    infer_group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    return parser
|
|
def inference_launch(mode, **kwargs):
    """Dispatch decoding to the inference routine selected by ``mode``.

    Args:
        mode: Name of the decoding mode. Only ``"vad"`` is handled here.
        **kwargs: Forwarded unchanged to the selected inference function.

    Returns:
        Whatever the selected inference function returns, or ``None`` when
        ``mode`` is not recognised.
    """
    if mode != "vad":
        # Report at ERROR rather than INFO: under the default logging
        # configuration INFO records are dropped, which would make an
        # unsupported mode fail silently for library callers.
        logging.error("Unknown decoding mode: {}".format(mode))
        return None

    # Imported here so the VAD module is only loaded when this mode is used.
    from funasr.bin.vad_inference import inference_modelscope

    return inference_modelscope(**kwargs)
|
|
def _select_gpu_id(output_dir, gpuid_list, njob):
    """Pick this job's GPU id from the comma-separated ``gpuid_list``.

    The job id is the integer after the last ``"."`` in ``output_dir``;
    each run of ``njob`` consecutive job ids shares one list entry.

    Returns:
        The selected GPU id with surrounding whitespace removed. It is an
        empty string when the selected entry is empty (e.g. no list given).

    Raises:
        ValueError: ``output_dir`` does not end with an integer job id.
        IndexError: ``gpuid_list`` has no entry for this job.
    """
    try:
        jobid = int(output_dir.split(".")[-1])
    except ValueError as err:
        raise ValueError(
            "--output_dir must end with '.<jobid>' when --ngpu > 0, "
            "got: {!r}".format(output_dir)
        ) from err

    gpu_ids = gpuid_list.split(",")
    slot = (jobid - 1) // njob
    try:
        return gpu_ids[slot].strip()
    except IndexError as err:
        raise IndexError(
            "--gpuid_list {!r} has no entry for job {} "
            "(njob={})".format(gpuid_list, jobid, njob)
        ) from err


def main(cmd=None):
    """Command-line entry point: parse options and launch decoding.

    Args:
        cmd: Optional list of argument strings handed to ``parse_args``.
            With ``None``, argparse reads the process command line.

    Raises:
        ValueError: ``--ngpu > 0`` and ``--output_dir`` lacks a numeric
            ``.<jobid>`` suffix.
        IndexError: ``--ngpu > 0`` and ``--gpuid_list`` has no entry for
            the job.
    """
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    parser.add_argument(
        "--mode",
        type=str,
        default="vad",
        help="The decoding mode",
    )
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    # Drop the "config" entry, if any, so it is not forwarded to the
    # inference function (presumably added by config_argparse; verify).
    kwargs.pop("config", None)

    # set logging messages
    logging.basicConfig(
        level=args.log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("Decoding args: {}".format(kwargs))

    # gpu setting
    if args.ngpu > 0:
        gpuid = _select_gpu_id(args.output_dir, args.gpuid_list, args.njob)
        if gpuid:
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
            os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
        else:
            # An empty CUDA_VISIBLE_DEVICES would hide every GPU, which
            # contradicts --ngpu > 0; keep the inherited environment.
            logging.warning(
                "--ngpu > 0 but no GPU id is available from --gpuid_list; "
                "leaving CUDA_VISIBLE_DEVICES unchanged"
            )

    inference_launch(**kwargs)
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|