From 3e77fd44304a67a2b2253b4e56fede9762bb8464 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: Thu, 20 Apr 2023 16:41:22 +0800
Subject: [PATCH] update

---
 funasr/utils/build_asr_model.py      |    2 -
 funasr/utils/build_model.py          |    3 +
 funasr/utils/build_pretrain_model.py |  105 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 108 insertions(+), 2 deletions(-)

diff --git a/funasr/utils/build_asr_model.py b/funasr/utils/build_asr_model.py
index f333969..e0275a0 100644
--- a/funasr/utils/build_asr_model.py
+++ b/funasr/utils/build_asr_model.py
@@ -210,7 +210,6 @@
 
     # frontend
     if args.input_size is None:
-        # Extract features in the model
         frontend_class = frontend_choices.get_class(args.frontend)
         if args.frontend == 'wav_frontend':
             frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
@@ -218,7 +217,6 @@
             frontend = frontend_class(**args.frontend_conf)
         input_size = frontend.output_size()
     else:
-        # Give features from data-loader
         args.frontend = None
         args.frontend_conf = {}
         frontend = None
diff --git a/funasr/utils/build_model.py b/funasr/utils/build_model.py
index b85b8dc..4b5b98a 100644
--- a/funasr/utils/build_model.py
+++ b/funasr/utils/build_model.py
@@ -1,9 +1,12 @@
 from funasr.utils.build_asr_model import build_asr_model
+from funasr.utils.build_pretrain_model import build_pretrain_model
 
 
 def build_model(args):
     if args.task_name == "asr":
         model = build_asr_model(args)
+    elif args.task_name == "pretrain":
+        model = build_pretrain_model(args)
     else:
         raise NotImplementedError("Not supported task: {}".format(args.task_name))
 
diff --git a/funasr/utils/build_pretrain_model.py b/funasr/utils/build_pretrain_model.py
new file mode 100644
index 0000000..e7554fa
--- /dev/null
+++ b/funasr/utils/build_pretrain_model.py
@@ -0,0 +1,105 @@
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.data2vec import Data2VecPretrainModel
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.specaug.specaug import SpecAug
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+# Frontend classes, selected through args.frontend.
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes={"default": DefaultFrontend, "sliding_window": SlidingWindow},
+    default="default",
+)
+
+# Spectrogram augmentation, selected through args.specaug (may be None).
+specaug_choices = ClassChoices(
+    name="specaug",
+    classes={"specaug": SpecAug},
+    default=None,
+    optional=True,
+)
+
+# Normalization layers, selected through args.normalize (may be None).
+normalize_choices = ClassChoices(
+    name="normalize",
+    classes={"global_mvn": GlobalMVN, "utterance_mvn": UtteranceMVN},
+    default=None,
+    optional=True,
+)
+
+# Encoders, selected through args.encoder.
+encoder_choices = ClassChoices(
+    name="encoder",
+    classes={"data2vec_encoder": Data2VecEncoder},
+    default="data2vec_encoder",
+)
+
+# Pretraining model classes, looked up by name when the model is built.
+model_choices = ClassChoices(
+    name="model",
+    classes={"data2vec": Data2VecPretrainModel},
+    default="data2vec",
+)
+
+# All registries of this task; the trailing comments name the option pair
+# associated with each one.
+class_choices_list = [
+    frontend_choices,  # --frontend and --frontend_conf
+    specaug_choices,  # --specaug and --specaug_conf
+    normalize_choices,  # --normalize and --normalize_conf
+    encoder_choices,  # --encoder and --encoder_conf
+    model_choices,  # --model and --model_conf
+]
+
+
+def build_pretrain_model(args):
+    """Build the pretraining model described by ``args``.
+
+    Raises ``NotImplementedError`` unless ``args.model_name`` is "data2vec".
+    Setting ``args.input_size`` skips the frontend and clears its args in place.
+    """
+    if args.model_name != "data2vec":
+        raise NotImplementedError("Not supported model: {}".format(args.model_name))
+
+    # frontend: only built when args.input_size is None
+    if args.input_size is None:
+        frontend_class = frontend_choices.get_class(args.frontend)
+        frontend = frontend_class(**args.frontend_conf)
+        input_size = frontend.output_size()
+    else:
+        args.frontend = None
+        args.frontend_conf = {}
+        frontend = None
+        input_size = args.input_size
+
+    # data augmentation for spectrogram (optional)
+    specaug = None
+    if args.specaug is not None:
+        specaug = specaug_choices.get_class(args.specaug)(**args.specaug_conf)
+
+    # normalization layer (optional)
+    normalize = None
+    if args.normalize is not None:
+        normalize = normalize_choices.get_class(args.normalize)(**args.normalize_conf)
+
+    # encoder
+    encoder_class = encoder_choices.get_class(args.encoder)
+    encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+    model_class = model_choices.get_class("data2vec")
+    model = model_class(
+        frontend=frontend,
+        specaug=specaug,
+        normalize=normalize,
+        encoder=encoder,
+    )
+
+    # initialize parameters
+    if args.init is not None:
+        initialize(model, args.init)
+
+    return model

--
Gitblit v1.9.1