From 7ea3836893bfdf1aac03952bb1ff2da2c6ef6e57 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 01 Aug 2023 14:18:32 +0800
Subject: [PATCH] remove obsolete funasr/bin/build_trainer_bak.py

---
 funasr/bin/build_trainer_bak.py |  167 -----------------------------------------
 1 file changed, 0 insertions(+), 167 deletions(-)

diff --git a/funasr/bin/build_trainer_bak.py b/funasr/bin/build_trainer_bak.py
deleted file mode 100644
index e7f28ed..0000000
--- a/funasr/bin/build_trainer_bak.py
+++ /dev/null
@@ -1,167 +0,0 @@
-import os
-
-import yaml
-
-def update_dct(fin_configs, root):
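-    """Recursively merge ``root`` into ``fin_configs`` and return the result.
-
-    Non-dict values in ``root`` override those in ``fin_configs``;
-    nested dicts are merged key by key instead of replaced wholesale.
-    """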
-    if root == {}:
-        return fin_configs
-    for root_key, root_value in root.items():
-        if not isinstance(root[root_key], dict):
-            fin_configs[root_key] = root[root_key]
-        else:
-            if root_key in fin_configs.keys():
-                result = update_dct(fin_configs[root_key], root[root_key])
-                fin_configs[root_key] = result
-            else:
-                fin_configs[root_key] = root[root_key]
-    return fin_configs
-
-
-def parse_args(mode):
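-    """Pick the Task class for the given mode and parse its CLI arguments."""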
-    if mode == "asr":
-        from funasr.tasks.asr import ASRTask
-    elif mode == "paraformer":
-        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-    elif mode == "paraformer_streaming":
-        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-    elif mode == "paraformer_vad_punc":
-        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-    elif mode == "uniasr":
-        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
-    elif mode == "mfcca":
-        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
-    elif mode == "tp":
-        from funasr.tasks.asr import ASRTaskAligner as ASRTask
-    else:
-        raise ValueError("Unknown mode: {}".format(mode))
-    parser = ASRTask.get_parser()
-    args = parser.parse_args()
-    return args, ASRTask
-
-
-def build_trainer(modelscope_dict,
-                  data_dir,
-                  output_dir,
-                  train_set="train",
-                  dev_set="validation",
-                  distributed=False,
-                  dataset_type="small",
-                  batch_bins=None,
-                  max_epoch=None,
-                  optim=None,
-                  lr=None,
-                  scheduler=None,
-                  scheduler_conf=None,
-                  specaug=None,
-                  specaug_conf=None,
-                  mate_params=None,
-                  **kwargs):
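-    """Build an ASR Task configured for fine-tuning a ModelScope model.
-
-    Merges the pretrained model config with the finetune config, applies
-    the explicit overrides (optimizer, lr, scheduler, specaug, ...), and
-    attaches the resulting arguments to the Task class as finetune_args.
-    """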
-    mode = modelscope_dict['mode']
-    args, ASRTask = parse_args(mode=mode)
-    # ddp related
-    distributed = args.local_rank is not None
-    args.local_rank = args.local_rank if args.local_rank is not None else 0
-    local_rank = args.local_rank
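-    # pin this process to a single visible GPU based on its local rank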
-    if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
-        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
-    else:
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
-
-    config = modelscope_dict['am_model_config']
-    finetune_config = modelscope_dict['finetune_config']
-    init_param = modelscope_dict['init_model']
-    cmvn_file = modelscope_dict['cmvn_file']
-    seg_dict_file = modelscope_dict['seg_dict']
-
-    # overwrite parameters
-    with open(config) as f:
-        configs = yaml.safe_load(f)
-    with open(finetune_config) as f:
-        finetune_configs = yaml.safe_load(f)
-        # set data_types
-        if dataset_type == "large":
-            if 'data_types' not in finetune_configs['dataset_conf']:
-                finetune_configs["dataset_conf"]["data_types"] = "sound,text"
-    finetune_configs = update_dct(configs, finetune_configs)
-    for key, value in finetune_configs.items():
-        if hasattr(args, key):
-            setattr(args, key, value)
-    if mate_params is not None:
-        for key, value in mate_params.items():
-            if hasattr(args, key):
-                setattr(args, key, value)
-    if mate_params is not None and "lora_params" in mate_params:
-        lora_params = mate_params['lora_params']
-        configs['encoder_conf'].update(lora_params)
-        configs['decoder_conf'].update(lora_params)
-
-    # prepare data
-    args.dataset_type = dataset_type
-    if args.dataset_type == "small":
-        args.train_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, train_set), "speech", "sound"],
-                                                  ["{}/{}/text".format(data_dir, train_set), "text", "text"]]
-        args.valid_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, dev_set), "speech", "sound"],
-                                                  ["{}/{}/text".format(data_dir, dev_set), "text", "text"]]
-    elif args.dataset_type == "large":
-        args.train_data_file = None
-        args.valid_data_file = None
-    else:
-        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
-    args.init_param = [init_param]
-    if mate_params is not None and "init_param" in mate_params:
-        if len(mate_params["init_param"]) != 0:
-            args.init_param = mate_params["init_param"]
-    args.cmvn_file = cmvn_file
-    if os.path.exists(seg_dict_file):
-        args.seg_dict_file = seg_dict_file
-    else:
-        args.seg_dict_file = None
-    args.data_dir = data_dir
-    args.train_set = train_set
-    args.dev_set = dev_set
-    args.output_dir = output_dir
-    args.gpu_id = args.local_rank
-    args.config = finetune_config
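-    # apply any explicit overrides passed to build_trainer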
-    if optim is not None:
-        args.optim = optim
-    if lr is not None:
-        args.optim_conf["lr"] = lr
-    if scheduler is not None:
-        args.scheduler = scheduler
-    if scheduler_conf is not None:
-        args.scheduler_conf = scheduler_conf
-    if specaug is not None:
-        args.specaug = specaug
-    if specaug_conf is not None:
-        args.specaug_conf = specaug_conf
-    if max_epoch is not None:
-        args.max_epoch = max_epoch
-    if batch_bins is not None:
-        if args.dataset_type == "small":
-            args.batch_bins = batch_bins
-        elif args.dataset_type == "large":
-            args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
-        else:
-            raise ValueError(f"Not supported dataset_type={args.dataset_type}")
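-    # YAML may yield the strings "null"/"none"/"None"; convert them to real None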
-    if args.normalize in ["null", "none", "None"]:
-        args.normalize = None
-    if args.patience in ["null", "none", "None"]:
-        args.patience = None
-    args.local_rank = local_rank
-    args.distributed = distributed
-    ASRTask.finetune_args = args
-
-    return ASRTask

--
Gitblit v1.9.1