嘉渊
2023-05-24 9b7311209e0a991e5afae2bc0e416cdcc78c18d5
update repo
1个文件已修改
1个文件已添加
128 ■■■■■ 已修改文件
egs/aishell/paraformer/run.sh 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/aishell/transformer/utils/gen_modelscope_configuration.py 117 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/aishell/paraformer/run.sh
@@ -101,6 +101,7 @@
    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
    echo "<unk>" >> ${token_list}
    vocab_size=$(cat ${token_list} | wc -l)
fi
# LM Training Stage
@@ -213,4 +214,14 @@
# Prepare files for ModelScope fine-tuning and inference.
# Copies the CMVN statistics next to the trained model and writes a
# configuration.json describing the model into the same directory.
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
    echo "stage 6: ModelScope Preparation"
    # vocab_size is assigned inside an earlier conditional block; when that
    # block was skipped (e.g. the run is resumed at this stage) fall back to
    # counting the lines of the token list so --vocab_size is never empty.
    vocab_size=${vocab_size:-$(wc -l < "${token_list}")}
    cp ${feats_dir}/data/${train_set}/cmvn/am.mvn ${exp_dir}/exp/${model_dir}/am.mvn
    # Option names must match the parser in gen_modelscope_configuration.py:
    # it defines --am_model_name / --nat / --dataset (not --am_model_file /
    # --nar), and a bare --model would be prefix-matched to --model_name and
    # override the value given above.
    python utils/gen_modelscope_configuration.py \
        --am_model_name $inference_asr_model \
        --mode paraformer \
        --model_name paraformer \
        --dataset aishell \
        --output_dir $exp_dir/exp/$model_dir \
        --vocab_size $vocab_size \
        --nat _nat \
        --tag $tag
fi
egs/aishell/transformer/utils/gen_modelscope_configuration.py
New file
@@ -0,0 +1,117 @@
import argparse
import json
import os
def _build_parser():
    """Create the command-line parser for the configuration generator."""
    parser = argparse.ArgumentParser(
        description="Generate a ModelScope configuration.json for an ASR model.",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="auto-speech-recognition",
        help="task name",
    )
    parser.add_argument(
        "--type",
        type=str,
        default="generic-asr",
        help="model type",
    )
    parser.add_argument(
        "--am_model_name",
        type=str,
        default="model.pb",
        help="model file name",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="paraformer",
        help="mode for decoding",
    )
    parser.add_argument(
        "--lang",
        type=str,
        default="zh-cn",
        help="language",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="batch size",
    )
    parser.add_argument(
        "--am_model_config",
        type=str,
        default="config.yaml",
        help="config file",
    )
    parser.add_argument(
        "--mvn_file",
        type=str,
        default="am.mvn",
        help="cmvn file",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="model name",
    )
    parser.add_argument(
        "--pipeline_type",
        type=str,
        default="asr-inference",
        help="pipeline type",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        help="vocab_size",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        help="dataset name",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        # Without a directory there is nowhere to write the file; fail with a
        # usage message instead of a TypeError from os.path.join(None, ...).
        required=True,
        help="output path",
    )
    parser.add_argument(
        "--nat",
        type=str,
        default="",
        help="suffix inserted right after 'asr' in the model name, e.g. '_nat'",
    )
    parser.add_argument(
        "--tag",
        type=str,
        default="exp1",
        help="model name tag",
    )
    return parser


def _build_configuration(args):
    """Assemble the dictionary that is serialized to configuration.json."""
    model_id = "speech_{}_asr{}-{}-16k-{}-vocab{}-pytorch-{}".format(
        args.model_name, args.nat, args.lang, args.dataset, args.vocab_size, args.tag)
    model = {
        "type": args.type,
        "am_model_name": args.am_model_name,
        "model_config": {
            "type": "pytorch",
            "code_base": "funasr",
            "mode": args.mode,
            "lang": args.lang,
            "batch_size": args.batch_size,
            "am_model_config": args.am_model_config,
            "mvn_file": args.mvn_file,
            "model": model_id,
        }
    }
    return {
        "model": model,
        "framework": "pytorch",
        "task": args.task,
        "pipeline": args.pipeline_type,
    }


def main(argv=None):
    """Parse the command line and write ``configuration.json``.

    Args:
        argv: Argument list to parse. ``None`` makes argparse read
            ``sys.argv[1:]``, which is what running the file as a script does.

    Raises:
        SystemExit: Raised by argparse on invalid or missing arguments.
    """
    args = _build_parser().parse_args(argv)
    # An empty path means "current directory" for os.path.join, and
    # os.makedirs("") would raise, so only create non-empty paths.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    config_path = os.path.join(args.output_dir, "configuration.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(_build_configuration(args), f)


if __name__ == '__main__':
    main()