| | |
| | | scheduler_conf=None, |
| | | specaug=None, |
| | | specaug_conf=None, |
| | | param_dict=None, |
| | | meta_dict=None, |
| | | **kwargs): |
| | | mode = modelscope_dict['mode'] |
| | | args, ASRTask = parse_args(mode=mode) |
| | |
| | | args.patience = None |
| | | args.local_rank = local_rank |
| | | args.distributed = distributed |
| | | for key, value in kwargs.items(): |
| | | args.key = value |
| | | if meta_dict is not None: |
| | | for key, value in meta_dict.items(): |
| | | args.key = value |
| | | ASRTask.finetune_args = args |
| | | |
| | | return ASRTask |
| | |
| | | "--lora_bias", |
| | | type=str, |
| | | default="none", |
| | | help="oss bucket.", |
| | | help="lora bias.", |
| | | ) |
| | | |
| | | return parser |
| | |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text |
| | | from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump |
| | | from funasr.modules.lora.utils import mark_only_lora_as_trainable |
| | | |
| | | try: |
| | | import wandb |
| | |
| | | default=None, |
| | | help="oss bucket.", |
| | | ) |
| | | group.add_argument( |
| | | "--enable_lora", |
| | | type=str2bool, |
| | | default=False, |
| | | help="Apply lora for finetuning.", |
| | | ) |
| | | group.add_argument( |
| | | "--lora_bias", |
| | | type=str, |
| | | default="none", |
| | | help="lora bias.", |
| | | ) |
| | | |
| | | cls.trainer.add_arguments(parser) |
| | | cls.add_task_arguments(parser) |
| | |
| | | dtype=getattr(torch, args.train_dtype), |
| | | device="cuda" if args.ngpu > 0 else "cpu", |
| | | ) |
| | | if args.enable_lora: |
| | | mark_only_lora_as_trainable(model, args.lora_bias) |
| | | for t in args.freeze_param: |
| | | for k, p in model.named_parameters(): |
| | | if k.startswith(t + ".") or k == t: |