From 9b7311209e0a991e5afae2bc0e416cdcc78c18d5 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 24 五月 2023 19:10:09 +0800
Subject: [PATCH] update repo

---
 egs/aishell/paraformer/run.sh                                 |   11 +++
 egs/aishell/transformer/utils/gen_modelscope_configuration.py |  117 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 128 insertions(+), 0 deletions(-)

diff --git a/egs/aishell/paraformer/run.sh b/egs/aishell/paraformer/run.sh
index 373c119..0c6b73c 100755
--- a/egs/aishell/paraformer/run.sh
+++ b/egs/aishell/paraformer/run.sh
@@ -101,6 +101,7 @@
     utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
         | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
     echo "<unk>" >> ${token_list}
+    vocab_size=$(cat ${token_list} | wc -l)
 fi
 
 # LM Training Stage
@@ -213,4 +214,14 @@
 # Prepare files for ModelScope fine-tuning and inference
 if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
     echo "stage 6: ModelScope Preparation"
+    cp ${feats_dir}/data/${train_set}/cmvn/am.mvn ${exp_dir}/exp/${model_dir}/am.mvn
+    python utils/gen_modelscope_configuration.py \
+        --am_model_file $inference_asr_model \
+        --mode paraformer \
+        --model_name paraformer \
+        --model $model_dir \
+        --output_dir $exp_dir/exp/$model_dir \
+        --vocab_size $vocab_size \
+        --nat _nat \
+        --tag $tag
 fi
\ No newline at end of file
diff --git a/egs/aishell/transformer/utils/gen_modelscope_configuration.py b/egs/aishell/transformer/utils/gen_modelscope_configuration.py
new file mode 100644
index 0000000..8414647
--- /dev/null
+++ b/egs/aishell/transformer/utils/gen_modelscope_configuration.py
@@ -0,0 +1,117 @@
+import argparse
+import json
+import os
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--task",
+        type=str,
+        default="auto-speech-recognition",
+        help="task name",
+    )
+    parser.add_argument(
+        "--type",
+        type=str,
+        default="generic-asr",
+    )
+    parser.add_argument(
+        "--am_model_name",
+        type=str,
+        default="model.pb",
+        help="model file name",
+    )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="paraformer",
+        help="mode for decoding",
+    )
+    parser.add_argument(
+        "--lang",
+        type=str,
+        default="zh-cn",
+        help="language",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="batch size",
+    )
+    parser.add_argument(
+        "--am_model_config",
+        type=str,
+        default="config.yaml",
+        help="config file",
+    )
+    parser.add_argument(
+        "--mvn_file",
+        type=str,
+        default="am.mvn",
+        help="cmvn file",
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        help="model name",
+    )
+    parser.add_argument(
+        "--pipeline_type",
+        type=str,
+        default="asr-inference",
+        help="pipeline type",
+    )
+    parser.add_argument(
+        "--vocab_size",
+        type=int,
+        help="vocab_size",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        help="dataset name",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        help="output path",
+    )
+    parser.add_argument(
+        "--nat",
+        type=str,
+        default="",
+        help="suffix inserted into the generated model name for non-autoregressive models (e.g. '_nat')",
+    )
+    parser.add_argument(
+        "--tag",
+        type=str,
+        default="exp1",
+        help="model name tag",
+    )
+    args = parser.parse_args()
+
+    model = {
+        "type": args.type,
+        "am_model_name": args.am_model_name,
+        "model_config": {
+            "type": "pytorch",
+            "code_base": "funasr",
+            "mode": args.mode,
+            "lang": args.lang,
+            "batch_size": args.batch_size,
+            "am_model_config": args.am_model_config,
+            "mvn_file": args.mvn_file,
+            "model": "speech_{}_asr{}-{}-16k-{}-vocab{}-pytorch-{}".format(args.model_name, args.nat, args.lang,
+                                                                           args.dataset, args.vocab_size, args.tag),
+        }
+    }
+    json_dict = {
+        "model": model,
+        "framework": "pytorch",
+        "task": args.task,
+        "pipeline": args.pipeline_type,
+    }
+
+    with open(os.path.join(args.output_dir, "configuration.json"), "w") as f:
+        json.dump(json_dict, f)

--
Gitblit v1.9.1