From eac9f111b502e4581b14dc718731bf7dc1c7d5f6 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: Thu, 20 Apr 2023 16:59:26 +0800
Subject: [PATCH] update
---
funasr/utils/build_asr_model.py | 1
funasr/utils/build_lm_model.py | 34 +++++++++++
funasr/utils/build_model.py | 2
funasr/utils/build_pretrain_model.py | 74 ++++++++++++------------
funasr/utils/build_args.py | 57 ++++++++++++++++++
5 files changed, 131 insertions(+), 37 deletions(-)
diff --git a/funasr/utils/build_args.py b/funasr/utils/build_args.py
index 1baf2d6..f57f495 100644
--- a/funasr/utils/build_args.py
+++ b/funasr/utils/build_args.py
@@ -79,7 +79,62 @@
default=None,
help="The file path of noise scp file.",
)
-
+ elif args.task_name == "pretrain":
+ from funasr.utils.build_pretrain_model import class_choices_list
+ for class_choices in class_choices_list:
+ # Append --<name> and --<name>_conf.
+ # e.g. --encoder and --encoder_conf
+ class_choices.add_arguments(parser)
+ parser.add_argument(
+ "--init",
+ type=lambda x: str_or_none(x.lower()),
+ default=None,
+ help="The initialization method",
+ choices=[
+ "chainer",
+ "xavier_uniform",
+ "xavier_normal",
+ "kaiming_uniform",
+ "kaiming_normal",
+ None,
+ ],
+ )
+ parser.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of input dimension of the feature",
+ )
+ parser.add_argument(
+ "--feats_type",
+ type=str,
+ default='fbank',
+ help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
+ )
+ parser.add_argument(
+ "--noise_db_range",
+ type=str,
+ default="13_15",
+ help="The range of noise decibel level.",
+ )
+ parser.add_argument(
+ "--pred_masked_weight",
+ type=float,
+ default=1.0,
+ help="weight for predictive loss for masked frames",
+ )
+ parser.add_argument(
+ "--pred_nomask_weight",
+ type=float,
+ default=0.0,
+ help="weight for predictive loss for unmasked frames",
+ )
+ parser.add_argument(
+ "--loss_weights",
+ type=float,
+ default=0.0,
+ help="weights for additional loss terms (not first one)",
+ )
else:
raise NotImplementedError("Not supported task: {}".format(args.task_name))
diff --git a/funasr/utils/build_asr_model.py b/funasr/utils/build_asr_model.py
index e0275a0..e42637c 100644
--- a/funasr/utils/build_asr_model.py
+++ b/funasr/utils/build_asr_model.py
@@ -345,6 +345,7 @@
else:
raise NotImplementedError("Not supported model: {}".format(args.model))
+ # initialize
if args.init is not None:
initialize(model, args.init)
diff --git a/funasr/utils/build_lm_model.py b/funasr/utils/build_lm_model.py
new file mode 100644
index 0000000..4fe4625
--- /dev/null
+++ b/funasr/utils/build_lm_model.py
@@ -0,0 +1,34 @@
+from funasr.lm.abs_model import AbsLM
+from funasr.lm.seq_rnn_lm import SequentialRNNLM
+from funasr.lm.transformer_lm import TransformerLM
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+lm_choices = ClassChoices(
+ "lm",
+ classes=dict(
+ seq_rnn=SequentialRNNLM,
+ transformer=TransformerLM,
+ ),
+ type_check=AbsLM,
+ default="seq_rnn",
+)
+
+class_choices_list = [
+ # --lm and --lm_conf
+ lm_choices
+]
+
+
+def build_lm_model(args):
+ # token_list
+ if args.token_list is not None:
+ with open(args.token_list) as f:
+ token_list = [line.rstrip() for line in f]
+ args.token_list = list(token_list)
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+ else:
+ vocab_size = None
+
+ return model
diff --git a/funasr/utils/build_model.py b/funasr/utils/build_model.py
index 4b5b98a..b774304 100644
--- a/funasr/utils/build_model.py
+++ b/funasr/utils/build_model.py
@@ -7,6 +7,8 @@
model = build_asr_model(args)
elif args.task_name == "pretrain":
model = build_pretrain_model(args)
+ elif args.task_name == "lm":
+ model = build_lm_model(args)
else:
raise NotImplementedError("Not supported task: {}".format(args.task_name))
diff --git a/funasr/utils/build_pretrain_model.py b/funasr/utils/build_pretrain_model.py
index e7554fa..e514215 100644
--- a/funasr/utils/build_pretrain_model.py
+++ b/funasr/utils/build_pretrain_model.py
@@ -57,39 +57,39 @@
def build_pretrain_model(args):
+ # frontend
+ if args.input_size is None:
+ frontend_class = frontend_choices.get_class(args.frontend)
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(
+ input_size=input_size,
+ **args.encoder_conf,
+ )
+
if args.model_name == "data2vec":
- # frontend
- if args.input_size is None:
- frontend_class = frontend_choices.get_class(args.frontend)
- frontend = frontend_class(**args.frontend_conf)
- input_size = frontend.output_size()
- else:
- args.frontend = None
- args.frontend_conf = {}
- frontend = None
- input_size = args.input_size
-
- # data augmentation for spectrogram
- if args.specaug is not None:
- specaug_class = specaug_choices.get_class(args.specaug)
- specaug = specaug_class(**args.specaug_conf)
- else:
- specaug = None
-
- # normalization layer
- if args.normalize is not None:
- normalize_class = normalize_choices.get_class(args.normalize)
- normalize = normalize_class(**args.normalize_conf)
- else:
- normalize = None
-
- # encoder
- encoder_class = encoder_choices.get_class(args.encoder)
- encoder = encoder_class(
- input_size=input_size,
- **args.encoder_conf,
- )
-
model_class = model_choices.get_class("data2vec")
model = model_class(
frontend=frontend,
@@ -97,9 +97,11 @@
normalize=normalize,
encoder=encoder,
)
+ else:
+ raise NotImplementedError("Not supported model: {}".format(args.model_name))
- # 7. Initialize
- if args.init is not None:
- initialize(model, args.init)
+ # initialize
+ if args.init is not None:
+ initialize(model, args.init)
- return model
+ return model
--
Gitblit v1.9.1