From cc2c1d1d53dea5d2c45f858d1baa5bd279f47987 Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: Wed, 31 May 2023 14:39:25 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR
---
funasr/build_utils/build_lm_model.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 57 insertions(+), 0 deletions(-)
diff --git a/funasr/build_utils/build_lm_model.py b/funasr/build_utils/build_lm_model.py
new file mode 100644
index 0000000..8f4a958
--- /dev/null
+++ b/funasr/build_utils/build_lm_model.py
@@ -0,0 +1,57 @@
+import logging
+
+from funasr.train.abs_model import AbsLM
+from funasr.train.abs_model import LanguageModel
+from funasr.models.seq_rnn_lm import SequentialRNNLM
+from funasr.models.transformer_lm import TransformerLM
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+# Registry of LM network classes, keyed by the name that build_lm_model reads
+# from args.lm. Entries are declared with type_check=AbsLM, and "seq_rnn" is
+# the registered default.
+lm_choices = ClassChoices(
+    "lm",
+    classes={"seq_rnn": SequentialRNNLM, "transformer": TransformerLM},
+    type_check=AbsLM,
+    default="seq_rnn",
+)
+# Registry of model classes, keyed by the name that build_lm_model reads from
+# args.model; "lm" (LanguageModel) is the only entry and the registered default.
+model_choices = ClassChoices(
+    "model",
+    classes={"lm": LanguageModel},
+    default="lm",
+)
+
+# Every choice registry defined in this module, in declaration order. The
+# note on each entry names the option pair associated with that registry.
+class_choices_list = [
+    lm_choices,  # --lm and --lm_conf
+    model_choices,  # --model and --model_conf
+]
+
+
+def build_lm_model(args):
+    """Build the ``args.lm`` network and return the ``args.model`` model built on it.
+
+    ``args.token_list`` may be None, a token file path (read as UTF-8, one token
+    per line) or an already-loaded list/tuple; a path is replaced in place by the
+    loaded list. ``initialize`` is applied to the model when ``args.init`` is set.
+    """
+    if args.token_list is not None:
+        if not isinstance(args.token_list, (list, tuple)):
+            # Explicit UTF-8: the locale default can fail on non-ASCII tokens.
+            with open(args.token_list, encoding="utf-8") as f:
+                args.token_list = [line.rstrip() for line in f]
+        vocab_size = len(args.token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
+    else:
+        vocab_size = None
+    lm_class = lm_choices.get_class(args.lm)
+    lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
+    model_class = model_choices.get_class(args.model)
+    model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
+    if args.init is not None:
+        initialize(model, args.init)
+    return model
--
Gitblit v1.9.1