| | |
| | | |
| | | # frontend |
| | | if args.input_size is None: |
| | | # Extract features in the model |
| | | frontend_class = frontend_choices.get_class(args.frontend) |
| | | if args.frontend == 'wav_frontend': |
| | | frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) |
| | |
| | | frontend = frontend_class(**args.frontend_conf) |
| | | input_size = frontend.output_size() |
| | | else: |
| | | # Give features from data-loader |
| | | args.frontend = None |
| | | args.frontend_conf = {} |
| | | frontend = None |
| | |
| | | token_list=token_list, |
| | | **args.model_conf, |
| | | ) |
| | | elif args.model == "paraformer": |
| | | elif args.model in ["paraformer", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]: |
| | | # predictor |
| | | predictor_class = predictor_choices.get_class(args.predictor) |
| | | predictor = predictor_class(**args.predictor_conf) |
| | |
| | | stride_conv=stride_conv, |
| | | **args.model_conf, |
| | | ) |
| | | |
| | | elif args.model == "timestamp_prediction": |
| | | model_class = model_choices.get_class(args.model) |
| | | model = model_class( |
| | | frontend=frontend, |
| | | encoder=encoder, |
| | | token_list=token_list, |
| | | **args.model_conf, |
| | | ) |
| | | else: |
| | | raise NotImplementedError("Not supported model: {}".format(args.model)) |
| | | |
| | | if args.init is not None: |
| | | initialize(model, args.init) |
| | | |
| | | return model |