Merge pull request #81 from alibaba-damo-academy/dev
Create vad_inference_launch.py
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | | import sys |
| | | from typing import Union, Dict, Any |
| | | |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | |
| | | |
def get_parser():
    """Build the command-line parser used for VAD decoding.

    Options are registered in four sections: general runtime options
    (logging, output directory, GPU/job layout, seed, dtype, workers),
    input-data options, model-file options and inference options.

    Returns:
        The populated ``config_argparse.ArgumentParser`` instance.
    """
    parser = config_argparse.ArgumentParser(
        description="VAD Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    log_levels = ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET")
    parser.add_argument(
        "--log_level",
        type=lambda level: level.upper(),
        default="INFO",
        choices=log_levels,
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu", type=int, default=0, help="The number of gpus. 0 indicates CPU mode"
    )
    parser.add_argument(
        "--njob", type=int, default=1, help="The number of jobs for each gpu"
    )
    parser.add_argument("--gpuid_list", type=str, default="", help="The visible gpus")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    # Where the input data comes from.
    data_group = parser.add_argument_group("Input data related")
    data_group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )
    data_group.add_argument("--key_file", type=str_or_none)
    data_group.add_argument(
        "--allow_variable_data_keys", type=str2bool, default=False
    )

    # The model-file options share one shape: a plain string with a help text.
    model_group = parser.add_argument_group("The model configuration related")
    model_options = (
        ("--vad_infer_config", "VAD infer configuration"),
        ("--vad_model_file", "VAD model parameter file"),
        ("--vad_cmvn_file", "Global CMVN file"),
        ("--vad_train_config", "VAD training configuration"),
    )
    for flag, description in model_options:
        model_group.add_argument(flag, type=str, help=description)

    infer_group = parser.add_argument_group("The inference configuration related")
    infer_group.add_argument(
        "--batch_size", type=int, default=1, help="The batch size for inference"
    )

    return parser
| | | |
| | | |
def inference_launch(mode, **kwargs):
    """Dispatch decoding to the backend selected by ``mode``.

    Args:
        mode: Name of the decoding mode. Only ``"vad"`` is handled.
        **kwargs: Forwarded unchanged to the selected inference entry point.

    Returns:
        Whatever ``inference_modelscope(**kwargs)`` returns for the ``"vad"``
        mode; ``None`` for any other mode (the unknown mode is logged).
    """
    if mode != "vad":
        logging.info("Unknown decoding mode: {}".format(mode))
        return None

    # Imported lazily so the VAD backend is only loaded when it is requested.
    from funasr.bin.vad_inference import inference_modelscope

    return inference_modelscope(**kwargs)
| | | |
| | | |
def main(cmd=None):
    """Parse the command line, configure logging and GPU visibility, then decode.

    Args:
        cmd: Optional argument list handed to ``parse_args``; ``None`` lets
            argparse read ``sys.argv``.
    """
    print(get_commandline_args(), file=sys.stderr)

    parser = get_parser()
    parser.add_argument("--mode", type=str, default="vad", help="The decoding mode")
    args = parser.parse_args(cmd)

    # vars() exposes the namespace's own dict, so the pop also drops the
    # "config" entry from ``args`` before the options are forwarded.
    kwargs = vars(args)
    kwargs.pop("config", None)

    # set logging messages
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=args.log_level, format=log_format)
    logging.info("Decoding args: {}".format(kwargs))

    # gpu setting
    if args.ngpu > 0:
        # The job id is the text after the last "." in output_dir.
        job_suffix = args.output_dir.split(".")[-1]
        jobid = int(job_suffix)
        # njob consecutive job ids map to the same entry of gpuid_list.
        gpu_ids = args.gpuid_list.split(",")
        selected_gpu = gpu_ids[(jobid - 1) // args.njob]
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = selected_gpu

    inference_launch(**kwargs)


if __name__ == "__main__":
    main()
| | |
| | | write_flag = False |
| | | if write_flag: |
| | | f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim)))) |
| | | f.flush() |
| | | |
| | | |
| | | def calc_shape(data_dir, dataset, frontend_conf, speech_length_min=-1, speech_length_max=-1, nj=32): |