游雁
2023-05-16 3d9f094e9652d4b84894c6fd4eae39a4a753b0f0
train
2个文件已修改
60 ■■■■■ 已修改文件
funasr/bin/asr_train.py 10 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/asr.py 50 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_train.py
@@ -34,6 +34,8 @@
        from funasr.tasks.asr import ASRTask
    if args.mode == "paraformer":
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    if args.mode == "uniasr":
        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
    ASRTask.main(args=args, cmd=cmd)
@@ -42,8 +44,7 @@
    args = parse_args()
    # setup local gpu_id
    if args.ngpu > 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    # DDP settings
    if args.ngpu > 1:
@@ -54,9 +55,10 @@
    # re-compute batch size: when dataset type is small
    if args.dataset_type == "small":
        if args.batch_size is not None and args.ngpu > 0:
        if args.batch_size is not None:
            args.batch_size = args.batch_size * args.ngpu
        if args.batch_bins is not None and args.ngpu > 0:
        if args.batch_bins is not None:
            args.batch_bins = args.batch_bins * args.ngpu
    main(args=args)
funasr/tasks/asr.py
@@ -282,6 +282,14 @@
        decoder_choices,
        # --predictor and --predictor_conf
        predictor_choices,
        # --encoder2 and --encoder2_conf
        encoder_choices2,
        # --decoder2 and --decoder2_conf
        decoder_choices2,
        # --predictor2 and --predictor2_conf
        predictor_choices2,
        # --stride_conv and --stride_conv_conf
        stride_conv_choices,
    ]
    # If you need to modify train() or eval() procedures, change Trainer class here
@@ -901,27 +909,27 @@
    # If you need more than one optimizers, change this value
    num_optimizers: int = 1
    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
        # --predictor and --predictor_conf
        predictor_choices,
    ]
    # # Add variable objects configurations
    # class_choices_list = [
    #     # --frontend and --frontend_conf
    #     frontend_choices,
    #     # --specaug and --specaug_conf
    #     specaug_choices,
    #     # --normalize and --normalize_conf
    #     normalize_choices,
    #     # --model and --model_conf
    #     model_choices,
    #     # --preencoder and --preencoder_conf
    #     preencoder_choices,
    #     # --encoder and --encoder_conf
    #     encoder_choices,
    #     # --postencoder and --postencoder_conf
    #     postencoder_choices,
    #     # --decoder and --decoder_conf
    #     decoder_choices,
    #     # --predictor and --predictor_conf
    #     predictor_choices,
    # ]
    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer