From eac9f111b502e4581b14dc718731bf7dc1c7d5f6 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: Thu, 20 Apr 2023 16:59:26 +0800
Subject: [PATCH] update

---
 funasr/utils/build_asr_model.py      |    1 
 funasr/utils/build_lm_model.py       |   34 +++++++++++
 funasr/utils/build_model.py          |    2 
 funasr/utils/build_pretrain_model.py |   74 ++++++++++++------------
 funasr/utils/build_args.py           |   57 ++++++++++++++++++
 5 files changed, 131 insertions(+), 37 deletions(-)

diff --git a/funasr/utils/build_args.py b/funasr/utils/build_args.py
index 1baf2d6..f57f495 100644
--- a/funasr/utils/build_args.py
+++ b/funasr/utils/build_args.py
@@ -79,7 +79,62 @@
             default=None,
             help="The file path of noise scp file.",
         )
-
+    elif args.task_name == "pretrain":
+        from funasr.utils.build_pretrain_model import class_choices_list
+        for class_choices in class_choices_list:
+            # Append --<name> and --<name>_conf.
+            # e.g. --encoder and --encoder_conf
+            class_choices.add_arguments(parser)
+        parser.add_argument(
+            "--init",
+            type=lambda x: str_or_none(x.lower()),
+            default=None,
+            help="The initialization method",
+            choices=[
+                "chainer",
+                "xavier_uniform",
+                "xavier_normal",
+                "kaiming_uniform",
+                "kaiming_normal",
+                None,
+            ],
+        )
+        parser.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+        parser.add_argument(
+            "--feats_type",
+            type=str,
+            default='fbank',
+            help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
+        )
+        parser.add_argument(
+            "--noise_db_range",
+            type=str,
+            default="13_15",
+            help="The range of noise decibel level.",
+        )
+        parser.add_argument(
+            "--pred_masked_weight",
+            type=float,
+            default=1.0,
+            help="weight for predictive loss for masked frames",
+        )
+        parser.add_argument(
+            "--pred_nomask_weight",
+            type=float,
+            default=0.0,
+            help="weight for predictive loss for unmasked frames",
+        )
+        parser.add_argument(
+            "--loss_weights",
+            type=float,
+            default=0.0,
+            help="weights for additional loss terms (not first one)",
+        )
     else:
         raise NotImplementedError("Not supported task: {}".format(args.task_name))
 
diff --git a/funasr/utils/build_asr_model.py b/funasr/utils/build_asr_model.py
index e0275a0..e42637c 100644
--- a/funasr/utils/build_asr_model.py
+++ b/funasr/utils/build_asr_model.py
@@ -345,6 +345,7 @@
     else:
         raise NotImplementedError("Not supported model: {}".format(args.model))
 
+    # initialize
     if args.init is not None:
         initialize(model, args.init)
 
diff --git a/funasr/utils/build_lm_model.py b/funasr/utils/build_lm_model.py
new file mode 100644
index 0000000..4fe4625
--- /dev/null
+++ b/funasr/utils/build_lm_model.py
@@ -0,0 +1,55 @@
+import logging
+
+from funasr.lm.abs_model import AbsLM
+from funasr.lm.seq_rnn_lm import SequentialRNNLM
+from funasr.lm.transformer_lm import TransformerLM
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+lm_choices = ClassChoices(
+    "lm",
+    classes=dict(
+        seq_rnn=SequentialRNNLM,
+        transformer=TransformerLM,
+    ),
+    type_check=AbsLM,
+    default="seq_rnn",
+)
+
+class_choices_list = [
+    # --lm and --lm_conf
+    lm_choices
+]
+
+
+def build_lm_model(args):
+    """Build the language model selected by ``args.lm``.
+
+    Reads the vocabulary from ``args.token_list`` (a path to a file with one
+    token per line, or None), replaces ``args.token_list`` with the loaded
+    list, instantiates the class registered in ``lm_choices`` under
+    ``args.lm`` with ``vocab_size`` and ``args.lm_conf``, and applies
+    ``args.init`` when it is set.
+
+    NOTE(review): the bare LM is returned; if the trainer expects a wrapper
+    module around it, add the wrapper here -- confirm against the caller.
+    """
+    # token_list
+    if args.token_list is not None:
+        with open(args.token_list, encoding="utf-8") as f:
+            token_list = [line.rstrip() for line in f]
+        args.token_list = token_list
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
+    else:
+        vocab_size = None
+
+    # lm
+    lm_class = lm_choices.get_class(args.lm)
+    model = lm_class(vocab_size=vocab_size, **args.lm_conf)
+
+    # initialize
+    if args.init is not None:
+        initialize(model, args.init)
+
+    return model
diff --git a/funasr/utils/build_model.py b/funasr/utils/build_model.py
index 4b5b98a..b774304 100644
--- a/funasr/utils/build_model.py
+++ b/funasr/utils/build_model.py
@@ -7,6 +7,8 @@
         model = build_asr_model(args)
     elif args.task_name == "pretrain":
         model = build_pretrain_model(args)
+    elif args.task_name == "lm":
+        model = build_lm_model(args)
     else:
         raise NotImplementedError("Not supported task: {}".format(args.task_name))
 
diff --git a/funasr/utils/build_pretrain_model.py b/funasr/utils/build_pretrain_model.py
index e7554fa..e514215 100644
--- a/funasr/utils/build_pretrain_model.py
+++ b/funasr/utils/build_pretrain_model.py
@@ -57,39 +57,39 @@
 
 
 def build_pretrain_model(args):
+    # frontend
+    if args.input_size is None:
+        frontend_class = frontend_choices.get_class(args.frontend)
+        frontend = frontend_class(**args.frontend_conf)
+        input_size = frontend.output_size()
+    else:
+        args.frontend = None
+        args.frontend_conf = {}
+        frontend = None
+        input_size = args.input_size
+
+    # data augmentation for spectrogram
+    if args.specaug is not None:
+        specaug_class = specaug_choices.get_class(args.specaug)
+        specaug = specaug_class(**args.specaug_conf)
+    else:
+        specaug = None
+
+    # normalization layer
+    if args.normalize is not None:
+        normalize_class = normalize_choices.get_class(args.normalize)
+        normalize = normalize_class(**args.normalize_conf)
+    else:
+        normalize = None
+
+    # encoder
+    encoder_class = encoder_choices.get_class(args.encoder)
+    encoder = encoder_class(
+        input_size=input_size,
+        **args.encoder_conf,
+    )
+
     if args.model_name == "data2vec":
-        # frontend
-        if args.input_size is None:
-            frontend_class = frontend_choices.get_class(args.frontend)
-            frontend = frontend_class(**args.frontend_conf)
-            input_size = frontend.output_size()
-        else:
-            args.frontend = None
-            args.frontend_conf = {}
-            frontend = None
-            input_size = args.input_size
-
-        # data augmentation for spectrogram
-        if args.specaug is not None:
-            specaug_class = specaug_choices.get_class(args.specaug)
-            specaug = specaug_class(**args.specaug_conf)
-        else:
-            specaug = None
-
-        # normalization layer
-        if args.normalize is not None:
-            normalize_class = normalize_choices.get_class(args.normalize)
-            normalize = normalize_class(**args.normalize_conf)
-        else:
-            normalize = None
-
-        # encoder
-        encoder_class = encoder_choices.get_class(args.encoder)
-        encoder = encoder_class(
-            input_size=input_size,
-            **args.encoder_conf,
-        )
-
         model_class = model_choices.get_class("data2vec")
         model = model_class(
             frontend=frontend,
@@ -97,9 +97,12 @@
             normalize=normalize,
             encoder=encoder,
         )
+    else:
+        # Report the attribute that the branch above actually tested.
+        raise NotImplementedError("Not supported model: {}".format(args.model_name))
 
-        # 7. Initialize
-        if args.init is not None:
-            initialize(model, args.init)
+    # initialize
+    if args.init is not None:
+        initialize(model, args.init)
 
-        return model
+    return model

--
Gitblit v1.9.1