speech_asr
2023-04-20 3e77fd44304a67a2b2253b4e56fede9762bb8464
update
2个文件已修改
1个文件已添加
110 ■■■■■ 已修改文件
funasr/utils/build_asr_model.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/build_model.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/build_pretrain_model.py 105 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/build_asr_model.py
@@ -210,7 +210,6 @@
    # frontend
    if args.input_size is None:
        # Extract features in the model
        frontend_class = frontend_choices.get_class(args.frontend)
        if args.frontend == 'wav_frontend':
            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
@@ -218,7 +217,6 @@
            frontend = frontend_class(**args.frontend_conf)
        input_size = frontend.output_size()
    else:
        # Give features from data-loader
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
funasr/utils/build_model.py
@@ -1,9 +1,12 @@
from funasr.utils.build_asr_model import build_asr_model
from funasr.utils.build_pretrain_model import build_pretrain_model
def build_model(args):
    """Build and return the model for the task named by ``args.task_name``.

    Dispatches to ``build_asr_model`` for ``"asr"`` and to
    ``build_pretrain_model`` for ``"pretrain"``; ``args`` is passed through
    unchanged to the selected builder.

    Args:
        args: namespace-like config object; only ``task_name`` is read here.

    Returns:
        Whatever the selected task builder returns.

    Raises:
        NotImplementedError: if ``args.task_name`` is not a supported task.
    """
    if args.task_name == "asr":
        model = build_asr_model(args)
    elif args.task_name == "pretrain":
        model = build_pretrain_model(args)
    else:
        raise NotImplementedError("Not supported task: {}".format(args.task_name))
    # Without this the built model would be discarded and callers get None.
    return model
funasr/utils/build_pretrain_model.py
New file
@@ -0,0 +1,105 @@
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.data2vec import Data2VecPretrainModel
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.specaug.specaug import SpecAug
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
# Registries of selectable components for the pretraining task. Each entry maps
# a config value (e.g. ``frontend: default``) to the class it selects.
frontend_choices = ClassChoices(
    name="frontend",
    classes={"default": DefaultFrontend, "sliding_window": SlidingWindow},
    default="default",
)
# Spectrogram augmentation is optional and disabled unless configured.
specaug_choices = ClassChoices(
    name="specaug",
    classes={"specaug": SpecAug},
    default=None,
    optional=True,
)
# Feature normalization is optional and disabled unless configured.
normalize_choices = ClassChoices(
    name="normalize",
    classes={"global_mvn": GlobalMVN, "utterance_mvn": UtteranceMVN},
    default=None,
    optional=True,
)
encoder_choices = ClassChoices(
    name="encoder",
    classes={"data2vec_encoder": Data2VecEncoder},
    default="data2vec_encoder",
)
model_choices = ClassChoices(
    name="model",
    classes={"data2vec": Data2VecPretrainModel},
    default="data2vec",
)
# All registries above, in the order the options are exposed.
class_choices_list = [
    frontend_choices,  # --frontend and --frontend_conf
    specaug_choices,  # --specaug and --specaug_conf
    normalize_choices,  # --normalize and --normalize_conf
    encoder_choices,  # --encoder and --encoder_conf
    model_choices,  # --model and --model_conf
]
def _build_frontend(args):
    """Return ``(frontend, input_size)`` for the encoder.

    If ``args.input_size`` is None, features are extracted inside the model:
    the frontend named by ``args.frontend`` is built from
    ``args.frontend_conf`` and its ``output_size()`` is the encoder input size.
    Otherwise features are given by the data loader: no frontend is built,
    ``args.frontend``/``args.frontend_conf`` are reset on ``args`` (a side
    effect), and ``args.input_size`` is used as-is.
    """
    if args.input_size is None:
        frontend_class = frontend_choices.get_class(args.frontend)
        frontend = frontend_class(**args.frontend_conf)
        return frontend, frontend.output_size()
    args.frontend = None
    args.frontend_conf = {}
    return None, args.input_size


def _build_optional(choices, args, key):
    """Build the optional component ``getattr(args, key)`` or return None.

    ``<key>_conf`` is only read from ``args`` when the component is enabled,
    so a missing conf attribute is tolerated for disabled components.
    """
    name = getattr(args, key)
    if name is None:
        return None
    component_class = choices.get_class(name)
    return component_class(**getattr(args, key + "_conf"))


def build_pretrain_model(args):
    """Build a self-supervised pretraining model from config ``args``.

    Only ``args.model_name == "data2vec"`` is supported. The model is
    assembled from an optional frontend, optional spectrogram augmentation,
    optional normalization and an encoder, then optionally initialized.

    Args:
        args: namespace-like config. Reads ``model_name``, ``input_size``,
            ``frontend``/``frontend_conf``, ``specaug``/``specaug_conf``,
            ``normalize``/``normalize_conf``, ``encoder``/``encoder_conf``
            and ``init``. ``frontend``/``frontend_conf`` are overwritten
            when ``input_size`` is given.

    Returns:
        The constructed model (the class registered as ``"data2vec"`` in
        ``model_choices``).

    Raises:
        NotImplementedError: if ``args.model_name`` is not supported
            (previously this silently returned None).
    """
    if args.model_name != "data2vec":
        raise NotImplementedError("Not supported model: {}".format(args.model_name))

    frontend, input_size = _build_frontend(args)
    # Data augmentation for spectrogram (optional).
    specaug = _build_optional(specaug_choices, args, "specaug")
    # Normalization layer (optional).
    normalize = _build_optional(normalize_choices, args, "normalize")

    encoder_class = encoder_choices.get_class(args.encoder)
    encoder = encoder_class(input_size=input_size, **args.encoder_conf)

    model_class = model_choices.get_class("data2vec")
    model = model_class(
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        encoder=encoder,
    )

    # Apply ``initialize`` only when an init method is configured.
    if args.init is not None:
        initialize(model, args.init)
    return model