speech_asr
2023-04-20 3e77fd44304a67a2b2253b4e56fede9762bb8464
funasr/utils/build_asr_model.py
@@ -210,7 +210,6 @@
    # frontend
    if args.input_size is None:
        # Extract features in the model
        frontend_class = frontend_choices.get_class(args.frontend)
        if args.frontend == 'wav_frontend':
            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
@@ -218,7 +217,6 @@
            frontend = frontend_class(**args.frontend_conf)
        input_size = frontend.output_size()
    else:
        # Give features from data-loader
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
@@ -268,7 +266,7 @@
            token_list=token_list,
            **args.model_conf,
        )
    elif args.model == "paraformer":
    elif args.model in ["paraformer", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
        # predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)
@@ -336,9 +334,18 @@
            stride_conv=stride_conv,
            **args.model_conf,
        )
    elif args.model == "timestamp_prediction":
        model_class = model_choices.get_class(args.model)
        model = model_class(
            frontend=frontend,
            encoder=encoder,
            token_list=token_list,
            **args.model_conf,
        )
    else:
        raise NotImplementedError("Not supported model: {}".format(args.model))
    if args.init is not None:
        initialize(model, args.init)
    return model