From 6f7e27eb7c2d0a7649ec8f14d167c8da8e29f906 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 16 五月 2023 15:07:20 +0800
Subject: [PATCH] Merge pull request #518 from alibaba-damo-academy/dev_wjm2

---
 funasr/build_utils/build_punc_model.py |   68 ++++++++++++++++++++++++++++++++++
 1 files changed, 68 insertions(+), 0 deletions(-)

diff --git a/funasr/build_utils/build_punc_model.py b/funasr/build_utils/build_punc_model.py
new file mode 100644
index 0000000..62ccaf2
--- /dev/null
+++ b/funasr/build_utils/build_punc_model.py
@@ -0,0 +1,68 @@
+import logging
+
+from funasr.models.target_delay_transformer import TargetDelayTransformer
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.torch_utils.initialize import initialize
+from funasr.train.abs_model import PunctuationModel
+from funasr.train.class_choices import ClassChoices
+
+punc_choices = ClassChoices(
+    "punctuation",
+    classes=dict(
+        target_delay=TargetDelayTransformer,
+        vad_realtime=VadRealtimeTransformer
+    ),
+    default="target_delay",
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        punc=PunctuationModel,
+    ),
+    default="punc",
+)
+class_choices_list = [
+    # --punc and --punc_conf
+    punc_choices,
+    # --model and --model_conf
+    model_choices
+]
+
+
+def build_punc_model(args):
+    # token_list and punc list
+    if isinstance(args.token_list, str):
+        with open(args.token_list, encoding="utf-8") as f:
+            token_list = [line.rstrip() for line in f]
+        args.token_list = token_list.copy()
+    if isinstance(args.punc_list, str):
+        with open(args.punc_list, encoding="utf-8") as f2:
+            pairs = [line.rstrip().split(":") for line in f2]
+        punc_list = [pair[0] for pair in pairs]
+        punc_weight_list = [float(pair[1]) for pair in pairs]
+        args.punc_list = punc_list.copy()
+    elif isinstance(args.punc_list, list):
+        punc_list = args.punc_list.copy()
+        punc_weight_list = [1] * len(punc_list)
+    if isinstance(args.token_list, (tuple, list)):
+        token_list = args.token_list.copy()
+    else:
+        raise RuntimeError("token_list must be str or dict")
+
+    vocab_size = len(token_list)
+    punc_size = len(punc_list)
+    logging.info(f"Vocabulary size: {vocab_size}")
+
+    # punc
+    punc_class = punc_choices.get_class(args.punctuation)
+    punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf)
+
+    if "punc_weight" in args.model_conf:
+        args.model_conf.pop("punc_weight")
+    model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
+
+    # initialize
+    if args.init is not None:
+        initialize(model, args.init)
+
+    return model

--
Gitblit v1.9.1