| | |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import argparse |
| | | import logging |
| | |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump |
| | | from funasr.modules.lora.utils import mark_only_lora_as_trainable |
| | | |
| | | |
| | | def get_parser(): |
| | |
| | | parser.add_argument( |
| | | "--init_param", |
| | | type=str, |
| | | action="append", |
| | | default=[], |
| | | nargs="*", |
| | | help="Specify the file path used for initialization of parameters. " |
| | | "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', " |
| | | "where file_path is the model file path, " |
| | |
| | | "--freeze_param", |
| | | type=str, |
| | | default=[], |
| | | nargs="*", |
| | | action="append", |
| | | help="Freeze parameters", |
| | | ) |
| | | |
| | |
| | | type=str, |
| | | default="validation", |
| | | help="dev dataset", |
| | | ) |
| | | parser.add_argument( |
| | | "--data_file_names", |
| | | type=str, |
| | | default="wav.scp,text", |
| | | help="input data files", |
| | | ) |
| | | parser.add_argument( |
| | | "--speed_perturb", |
| | |
| | | default=None, |
| | | help="oss bucket.", |
| | | ) |
| | | parser.add_argument( |
| | | "--enable_lora", |
| | | type=str2bool, |
| | | default=False, |
| | | help="Apply lora for finetuning.", |
| | | ) |
| | | parser.add_argument( |
| | | "--lora_bias", |
| | | type=str, |
| | | default="none", |
| | | help="lora bias.", |
| | | ) |
| | | |
| | | return parser |
| | | |
| | |
| | | dtype=getattr(torch, args.train_dtype), |
| | | device="cuda" if args.ngpu > 0 else "cpu", |
| | | ) |
| | | if args.enable_lora: |
| | | mark_only_lora_as_trainable(model, args.lora_bias) |
| | | for t in args.freeze_param: |
| | | for k, p in model.named_parameters(): |
| | | if k.startswith(t + ".") or k == t: |
| | | logging.info(f"Setting {k}.requires_grad = False") |
| | | p.requires_grad = False |
| | | |
| | | optimizers = build_optimizer(args, model=model) |
| | | schedulers = build_scheduler(args, optimizers) |
| | | |