From eac9f111b502e4581b14dc718731bf7dc1c7d5f6 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 20 四月 2023 16:59:26 +0800
Subject: [PATCH] update

---
 funasr/utils/build_pretrain_model.py |   74 +++++++++++++++++++------------------
 1 files changed, 38 insertions(+), 36 deletions(-)

diff --git a/funasr/utils/build_pretrain_model.py b/funasr/utils/build_pretrain_model.py
index e7554fa..e514215 100644
--- a/funasr/utils/build_pretrain_model.py
+++ b/funasr/utils/build_pretrain_model.py
@@ -57,39 +57,39 @@
 
 
 def build_pretrain_model(args):
+    # frontend
+    if args.input_size is None:
+        frontend_class = frontend_choices.get_class(args.frontend)
+        frontend = frontend_class(**args.frontend_conf)
+        input_size = frontend.output_size()
+    else:
+        args.frontend = None
+        args.frontend_conf = {}
+        frontend = None
+        input_size = args.input_size
+
+    # data augmentation for spectrogram
+    if args.specaug is not None:
+        specaug_class = specaug_choices.get_class(args.specaug)
+        specaug = specaug_class(**args.specaug_conf)
+    else:
+        specaug = None
+
+    # normalization layer
+    if args.normalize is not None:
+        normalize_class = normalize_choices.get_class(args.normalize)
+        normalize = normalize_class(**args.normalize_conf)
+    else:
+        normalize = None
+
+    # encoder
+    encoder_class = encoder_choices.get_class(args.encoder)
+    encoder = encoder_class(
+        input_size=input_size,
+        **args.encoder_conf,
+    )
+
     if args.model_name == "data2vec":
-        # frontend
-        if args.input_size is None:
-            frontend_class = frontend_choices.get_class(args.frontend)
-            frontend = frontend_class(**args.frontend_conf)
-            input_size = frontend.output_size()
-        else:
-            args.frontend = None
-            args.frontend_conf = {}
-            frontend = None
-            input_size = args.input_size
-
-        # data augmentation for spectrogram
-        if args.specaug is not None:
-            specaug_class = specaug_choices.get_class(args.specaug)
-            specaug = specaug_class(**args.specaug_conf)
-        else:
-            specaug = None
-
-        # normalization layer
-        if args.normalize is not None:
-            normalize_class = normalize_choices.get_class(args.normalize)
-            normalize = normalize_class(**args.normalize_conf)
-        else:
-            normalize = None
-
-        # encoder
-        encoder_class = encoder_choices.get_class(args.encoder)
-        encoder = encoder_class(
-            input_size=input_size,
-            **args.encoder_conf,
-        )
-
         model_class = model_choices.get_class("data2vec")
         model = model_class(
             frontend=frontend,
@@ -97,9 +97,11 @@
             normalize=normalize,
             encoder=encoder,
         )
+    else:
+        raise NotImplementedError("Not supported model: {}".format(args.model))
 
-        # 7. Initialize
-        if args.init is not None:
-            initialize(model, args.init)
+    # initialize
+    if args.init is not None:
+        initialize(model, args.init)
 
-        return model
+    return model

--
Gitblit v1.9.1