嘉渊
2023-05-25 60e592456ac63f8e5d4c9a749629039fb4773b9a
update repo
1个文件已添加
118 ■■■■■ 已修改文件
egs/aishell2/transformer/utils/gen_modelscope_configuration.py 118 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/aishell2/transformer/utils/gen_modelscope_configuration.py
New file
@@ -0,0 +1,118 @@
import argparse
import json
import os
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task",
        type=str,
        default="auto-speech-recognition",
        help="task name",
    )
    parser.add_argument(
        "--type",
        type=str,
        default="generic-asr",
    )
    parser.add_argument(
        "--am_model_name",
        type=str,
        default="model.pb",
        help="model file name",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="paraformer",
        help="mode for decoding",
    )
    parser.add_argument(
        "--lang",
        type=str,
        default="zh-cn",
        help="language",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="batch size",
    )
    parser.add_argument(
        "--am_model_config",
        type=str,
        default="config.yaml",
        help="config file",
    )
    parser.add_argument(
        "--mvn_file",
        type=str,
        default="am.mvn",
        help="cmvn file",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="model name",
    )
    parser.add_argument(
        "--pipeline_type",
        type=str,
        default="asr-inference",
        help="pipeline type",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        help="vocab_size",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        help="dataset name",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="output path",
    )
    parser.add_argument(
        "--nat",
        type=str,
        default="",
        help="nat",
    )
    parser.add_argument(
        "--tag",
        type=str,
        default="exp1",
        help="model name tag",
    )
    args = parser.parse_args()
    model = {
        "type": args.type,
        "am_model_name": args.am_model_name,
        "model_config": {
            "type": "pytorch",
            "code_base": "funasr",
            "mode": args.mode,
            "lang": args.lang,
            "batch_size": args.batch_size,
            "am_model_config": args.am_model_config,
            "mvn_file": args.mvn_file,
            "model": "speech_{}_asr{}-{}-16k-{}-vocab{}-pytorch-{}".format(args.model_name, args.nat, args.lang,
                                                                           args.dataset, args.vocab_size, args.tag),
        }
    }
    pipeline = {"type": args.pipeline_type}
    json_dict = {
        "framework": "pytorch",
        "task": args.task,
        "model": model,
        "pipeline": pipeline,
    }
    with open(os.path.join(args.output_dir, "configuration.json"), "w") as f:
        json.dump(json_dict, f, indent=4)