| | |
| | | from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN |
| | | from funasr.models.decoder.transformer_decoder import TransformerDecoder |
| | | from funasr.models.decoder.rnnt_decoder import RNNTDecoder |
| | | from funasr.models.joint_net.joint_network import JointNetwork |
| | | from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder |
| | | from funasr.models.e2e_asr import ASRModel |
| | | from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer |
| | | from funasr.models.e2e_asr_mfcca import MFCCA |
| | | |
| | | from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel |
| | | |
| | | from funasr.models.e2e_sa_asr import SAASRModel |
| | | from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer |
| | | |
| | | from funasr.models.e2e_tp import TimestampPredictor |
| | | from funasr.models.e2e_uni_asr import UniASR |
| | | from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder |
| | | from funasr.models.encoder.data2vec_encoder import Data2VecEncoder |
| | | from funasr.models.encoder.mfcca_encoder import MFCCAEncoder |
| | |
| | | from funasr.models.frontend.s3prl import S3prlFrontend |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.models.frontend.windowing import SlidingWindow |
| | | from funasr.models.joint_net.joint_network import JointNetwork |
| | | from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3 |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | from funasr.models.specaug.specaug import SpecAugLFR |
| | |
| | | paraformer_bert=ParaformerBert, |
| | | bicif_paraformer=BiCifParaformer, |
| | | contextual_paraformer=ContextualParaformer, |
| | | neatcontextual_paraformer=NeatContextualParaformer, |
| | | mfcca=MFCCA, |
| | | timestamp_prediction=TimestampPredictor, |
| | | rnnt=TransducerModel, |
| | |
| | | |
| | | def build_asr_model(args): |
| | | # token_list |
| | | if args.token_list is not None: |
| | | with open(args.token_list) as f: |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | | token_list = [line.rstrip() for line in f] |
| | | args.token_list = list(token_list) |
| | | vocab_size = len(token_list) |
| | | logging.info(f"Vocabulary size: {vocab_size}") |
| | | elif isinstance(args.token_list, (tuple, list)): |
| | | token_list = list(args.token_list) |
| | | vocab_size = len(token_list) |
| | | logging.info(f"Vocabulary size: {vocab_size}") |
| | | else: |
| | | token_list = None |
| | | vocab_size = None |
| | | |
| | | # frontend |
| | | if args.input_size is None: |
| | | if hasattr(args, "input_size") and args.input_size is None: |
| | | frontend_class = frontend_choices.get_class(args.frontend) |
| | | if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend': |
| | | frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) |
| | |
| | | args.frontend = None |
| | | args.frontend_conf = {} |
| | | frontend = None |
| | | input_size = args.input_size |
| | | input_size = args.input_size if hasattr(args, "input_size") else None |
| | | |
| | | # data augmentation for spectrogram |
| | | if args.specaug is not None: |
| | |
| | | # normalization layer |
| | | if args.normalize is not None: |
| | | normalize_class = normalize_choices.get_class(args.normalize) |
| | | normalize = normalize_class(**args.normalize_conf) |
| | | if args.model == "mfcca": |
| | | normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf) |
| | | else: |
| | | normalize = normalize_class(**args.normalize_conf) |
| | | else: |
| | | normalize = None |
| | | |
| | |
| | | token_list=token_list, |
| | | **args.model_conf, |
| | | ) |
| | | elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]: |
| | | elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer", |
| | | "contextual_paraformer", "neatcontextual_paraformer"]: |
| | | # predictor |
| | | predictor_class = predictor_choices.get_class(args.predictor) |
| | | predictor = predictor_class(**args.predictor_conf) |