jmwang66
2023-08-07 cf8e000a84e888495dcf30c4dbfecea1ee7ab4e2
funasr/bin/build_trainer.py
@@ -2,7 +2,6 @@
import yaml
def update_dct(fin_configs, root):
    if root == {}:
        return {}
@@ -23,10 +22,16 @@
        from funasr.tasks.asr import ASRTask as ASRTask
    elif mode == "paraformer":
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    elif mode == "paraformer_streaming":
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    elif mode == "paraformer_vad_punc":
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    elif mode == "uniasr":
        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
    elif mode == "mfcca":
        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
    elif mode == "tp":
        from funasr.tasks.asr import ASRTaskAligner as ASRTask
    else:
        raise ValueError("Unknown mode: {}".format(mode))
    parser = ASRTask.get_parser()
@@ -34,8 +39,23 @@
    return args, ASRTask
def build_trainer(modelscope_dict, data_dir, output_dir, train_set="train", dev_set="validation", distributed=False,
                  dataset_type="small", lr=None, batch_bins=None, max_epoch=None, mate_params=None):
def build_trainer(modelscope_dict,
                  data_dir,
                  output_dir,
                  train_set="train",
                  dev_set="validation",
                  distributed=False,
                  dataset_type="small",
                  batch_bins=None,
                  max_epoch=None,
                  optim=None,
                  lr=None,
                  scheduler=None,
                  scheduler_conf=None,
                  specaug=None,
                  specaug_conf=None,
                  mate_params=None,
                  **kwargs):
    mode = modelscope_dict['mode']
    args, ASRTask = parse_args(mode=mode)
    # ddp related
@@ -64,11 +84,21 @@
        finetune_configs = yaml.safe_load(f)
        # set data_types
        if dataset_type == "large":
            finetune_configs["dataset_conf"]["data_types"] = "sound,text"
            # finetune_configs["dataset_conf"]["data_types"] = "sound,text"
            if 'data_types' not in finetune_configs['dataset_conf']:
                finetune_configs["dataset_conf"]["data_types"] = "sound,text"
    finetune_configs = update_dct(configs, finetune_configs)
    for key, value in finetune_configs.items():
        if hasattr(args, key):
            setattr(args, key, value)
    if mate_params is not None:
        for key, value in mate_params.items():
            if hasattr(args, key):
                setattr(args, key, value)
    if mate_params is not None and "lora_params" in mate_params:
        lora_params = mate_params['lora_params']
        configs['encoder_conf'].update(lora_params)
        configs['decoder_conf'].update(lora_params)
    # prepare data
    args.dataset_type = dataset_type
@@ -83,6 +113,9 @@
    else:
        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
    args.init_param = [init_param]
    if mate_params is not None and "init_param" in mate_params:
        if len(mate_params["init_param"]) != 0:
            args.init_param = mate_params["init_param"]
    args.cmvn_file = cmvn_file
    if os.path.exists(seg_dict_file):
        args.seg_dict_file = seg_dict_file
@@ -94,8 +127,18 @@
    args.output_dir = output_dir
    args.gpu_id = args.local_rank
    args.config = finetune_config
    if optim is not None:
        args.optim = optim
    if lr is not None:
        args.optim_conf["lr"] = lr
    if scheduler is not None:
        args.scheduler = scheduler
    if scheduler_conf is not None:
        args.scheduler_conf = scheduler_conf
    if specaug is not None:
        args.specaug = specaug
    if specaug_conf is not None:
        args.specaug_conf = specaug_conf
    if max_epoch is not None:
        args.max_epoch = max_epoch
    if batch_bins is not None: