| | |
| | | import argparse |
| | | import logging |
| | | import os |
| | | from pathlib import Path |
| | | from typing import Callable |
| | | from typing import Collection |
| | | from typing import Dict |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | import os |
| | | from pathlib import Path |
| | | from typing import Tuple |
| | | from typing import Union |
| | | import yaml |
| | | |
| | | import numpy as np |
| | | import torch |
| | | import yaml |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.preprocessor import CommonPreprocessor |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.decoder.rnn_decoder import RNNDecoder |
| | | from funasr.models.decoder.transformer_decoder import ( |
| | | DynamicConvolution2DTransformerDecoder, # noqa: H301 |
| | | ) |
| | | from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder |
| | | from funasr.models.decoder.transformer_decoder import ( |
| | | LightweightConvolution2DTransformerDecoder, # noqa: H301 |
| | | ) |
| | | from funasr.models.decoder.transformer_decoder import ( |
| | | LightweightConvolutionTransformerDecoder, # noqa: H301 |
| | | ) |
| | | from funasr.models.decoder.transformer_decoder import TransformerDecoder |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder |
| | | from funasr.models.encoder.data2vec_encoder import Data2VecEncoder |
| | | from funasr.models.encoder.rnn_encoder import RNNEncoder |
| | | from funasr.models.encoder.transformer_encoder import TransformerEncoder |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.models.frontend.default import DefaultFrontend |
| | | from funasr.models.frontend.fused import FusedFrontends |
| | | from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline |
| | | from funasr.models.frontend.s3prl import S3prlFrontend |
| | | from funasr.models.frontend.windowing import SlidingWindow |
| | | from funasr.models.postencoder.abs_postencoder import AbsPostEncoder |
| | | from funasr.models.postencoder.hugging_face_transformers_postencoder import ( |
| | | HuggingFaceTransformersPostEncoder, # noqa: H301 |
| | | ) |
| | | from funasr.models.preencoder.abs_preencoder import AbsPreEncoder |
| | | from funasr.models.preencoder.linear import LinearProjection |
| | | from funasr.models.preencoder.sinc import LightweightSincConvs |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.layers.global_mvn import GlobalMVN |
| | | from funasr.layers.utterance_mvn import UtteranceMVN |
| | | from funasr.tasks.abs_task import AbsTask |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.train.class_choices import ClassChoices |
| | | from funasr.train.trainer import Trainer |
| | | from funasr.utils.get_default_kwargs import get_default_kwargs |
| | | from funasr.utils.nested_dict_action import NestedDictAction |
| | | from funasr.utils.types import float_or_none |
| | | from funasr.utils.types import int_or_none |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str_or_none |
| | | |
| | | from funasr.models.specaug.specaug import SpecAugLFR |
| | | from funasr.models.predictor.cif import CifPredictor, CifPredictorV2 |
| | | from funasr.modules.subsampling import Conv1dSubsampling |
| | | from funasr.models.e2e_vad import E2EVadModel |
| | | from funasr.models.encoder.fsmn_encoder import FSMN |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.models.frontend.default import DefaultFrontend |
| | | from funasr.models.frontend.fused import FusedFrontends |
| | | from funasr.models.frontend.s3prl import S3prlFrontend |
| | | from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline |
| | | from funasr.models.frontend.windowing import SlidingWindow |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | from funasr.models.specaug.specaug import SpecAugLFR |
| | | from funasr.tasks.abs_task import AbsTask |
| | | from funasr.train.class_choices import ClassChoices |
| | | from funasr.train.trainer import Trainer |
| | | from funasr.utils.types import float_or_none |
| | | from funasr.utils.types import int_or_none |
| | | from funasr.utils.types import str_or_none |
| | | |
| | | frontend_choices = ClassChoices( |
| | | name="frontend", |
| | |
| | | model_class = model_choices.get_class(args.model) |
| | | except AttributeError: |
| | | model_class = model_choices.get_class("e2evad") |
| | | |
| | | |
| | | # 1. frontend |
| | | if args.input_size is None: |
| | | # Extract features in the model |
| | |
| | | args.frontend_conf = {} |
| | | frontend = None |
| | | input_size = args.input_size |
| | | |
| | | |
| | | model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend) |
| | | |
| | | return model |
| | |
| | | |
| | | with config_file.open("r", encoding="utf-8") as f: |
| | | args = yaml.safe_load(f) |
| | | # NOTE: the `if cmvn_file is not None:` guard is commented out, so |
| | | # cmvn_file is written into args unconditionally on the next line. |
| | | args["cmvn_file"] = cmvn_file |
| | | args = argparse.Namespace(**args) |
| | | model = cls.build_model(args) |