| | |
| | | default=None, |
| | | help="The file path of noise scp file.", |
| | | ) |
| | | |
| | | elif args.task_name == "pretrain": |
| | | from funasr.utils.build_pretrain_model import class_choices_list |
| | | for class_choices in class_choices_list: |
| | | # Append --<name> and --<name>_conf. |
| | | # e.g. --encoder and --encoder_conf |
| | | class_choices.add_arguments(parser) |
| | | parser.add_argument( |
| | | "--init", |
| | | type=lambda x: str_or_none(x.lower()), |
| | | default=None, |
| | | help="The initialization method", |
| | | choices=[ |
| | | "chainer", |
| | | "xavier_uniform", |
| | | "xavier_normal", |
| | | "kaiming_uniform", |
| | | "kaiming_normal", |
| | | None, |
| | | ], |
| | | ) |
| | | parser.add_argument( |
| | | "--input_size", |
| | | type=int_or_none, |
| | | default=None, |
| | | help="The number of input dimension of the feature", |
| | | ) |
| | | parser.add_argument( |
| | | "--feats_type", |
| | | type=str, |
| | | default='fbank', |
| | | help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)", |
| | | ) |
| | | parser.add_argument( |
| | | "--noise_db_range", |
| | | type=str, |
| | | default="13_15", |
| | | help="The range of noise decibel level.", |
| | | ) |
| | | parser.add_argument( |
| | | "--pred_masked_weight", |
| | | type=float, |
| | | default=1.0, |
| | | help="weight for predictive loss for masked frames", |
| | | ) |
| | | parser.add_argument( |
| | | "--pred_nomask_weight", |
| | | type=float, |
| | | default=0.0, |
| | | help="weight for predictive loss for unmasked frames", |
| | | ) |
| | | parser.add_argument( |
| | | "--loss_weights", |
| | | type=float, |
| | | default=0.0, |
| | | help="weights for additional loss terms (not first one)", |
| | | ) |
| | | else: |
| | | raise NotImplementedError("Not supported task: {}".format(args.task_name)) |
| | | |
| | |
| | | else: |
| | | raise NotImplementedError("Not supported model: {}".format(args.model)) |
| | | |
| | | # initialize |
| | | if args.init is not None: |
| | | initialize(model, args.init) |
| | | |
| New file |
| | |
| | | from funasr.lm.abs_model import AbsLM |
| | | from funasr.lm.seq_rnn_lm import SequentialRNNLM |
| | | from funasr.lm.transformer_lm import TransformerLM |
| | | from funasr.torch_utils.initialize import initialize |
| | | from funasr.train.class_choices import ClassChoices |
| | | |
# Registry for the --lm / --lm_conf CLI options: maps each selectable
# name to its language-model implementation (all must subclass AbsLM).
lm_choices = ClassChoices(
    "lm",
    classes={
        "seq_rnn": SequentialRNNLM,
        "transformer": TransformerLM,
    },
    type_check=AbsLM,
    default="seq_rnn",
)

# Every ClassChoices listed here gets its --<name> and --<name>_conf
# arguments registered on the training argument parser.
class_choices_list = [
    lm_choices,  # --lm and --lm_conf
]
| | | |
| | | |
def build_pretrain_model(args):
    """Build a model for the pretraining task from parsed CLI arguments.

    NOTE(review): the visible code only resolves the vocabulary from
    ``args.token_list``; ``model`` is returned but never assigned here,
    so the actual model-construction logic appears to be missing or
    truncated — confirm against the complete file before relying on this.
    """
    # token_list: read one token per line, normalize args.token_list to a
    # plain list, and derive the vocabulary size (None when no list given).
    if args.token_list is not None:
        with open(args.token_list) as f:
            token_list = [line.rstrip() for line in f]
        args.token_list = list(token_list)
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
    else:
        vocab_size = None

    return model
| | |
| | | model = build_asr_model(args) |
| | | elif args.task_name == "pretrain": |
| | | model = build_pretrain_model(args) |
| | | elif args.task_name == "lm": |
| | | model = build_lm_model(args) |
| | | else: |
| | | raise NotImplementedError("Not supported task: {}".format(args.task_name)) |
| | | |
| | |
| | | |
| | | |
def build_pretrain_model(args):
    """Build a pretraining model (currently only ``data2vec``) from args.

    Args:
        args: parsed argparse namespace. Reads ``model_name``,
            ``input_size``, ``frontend``/``frontend_conf``,
            ``specaug``/``specaug_conf``, ``normalize``/``normalize_conf``,
            ``encoder``/``encoder_conf`` and ``init``. When ``input_size``
            is given explicitly, ``frontend``/``frontend_conf`` are reset
            to ``None``/``{}`` as a side effect (as in the original).

    Returns:
        The constructed (and optionally weight-initialized) model.

    Raises:
        NotImplementedError: if ``args.model_name`` is not supported.
    """
    # Guard clause: dispatch is on model_name.  Fixed: the original
    # formatted the error with ``args.model`` even though it tested
    # ``args.model_name``.
    if args.model_name != "data2vec":
        raise NotImplementedError(
            "Not supported model: {}".format(args.model_name)
        )

    # frontend: build the feature frontend unless the feature dimension
    # was supplied directly via --input_size.  (The original built all of
    # frontend/specaug/normalize/encoder twice — once before and once
    # inside the data2vec branch; one pass is sufficient and identical.)
    if args.input_size is None:
        frontend_class = frontend_choices.get_class(args.frontend)
        frontend = frontend_class(**args.frontend_conf)
        input_size = frontend.output_size()
    else:
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size

    # data augmentation for spectrogram
    if args.specaug is not None:
        specaug_class = specaug_choices.get_class(args.specaug)
        specaug = specaug_class(**args.specaug_conf)
    else:
        specaug = None

    # normalization layer
    if args.normalize is not None:
        normalize_class = normalize_choices.get_class(args.normalize)
        normalize = normalize_class(**args.normalize_conf)
    else:
        normalize = None

    # encoder
    encoder_class = encoder_choices.get_class(args.encoder)
    encoder = encoder_class(
        input_size=input_size,
        **args.encoder_conf,
    )

    model_class = model_choices.get_class("data2vec")
    model = model_class(
        frontend=frontend,
        # NOTE(review): specaug was built but never passed in the original
        # (a blank line sat exactly here) — restored; confirm model_class
        # accepts a ``specaug`` keyword.
        specaug=specaug,
        normalize=normalize,
        encoder=encoder,
    )

    # initialize weights (the original ran this block twice; once suffices)
    if args.init is not None:
        initialize(model, args.init)

    return model