haoneng.lhn
2023-07-19 7a19c526024ce68b0a35f37b30b12037dba82240
add lora finetune code
3个文件已修改
24 ■■■■ 已修改文件
funasr/bin/build_trainer.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/abs_task.py 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/build_trainer.py
@@ -55,7 +55,7 @@
                  scheduler_conf=None,
                  specaug=None,
                  specaug_conf=None,
                  param_dict=None,
                  meta_dict=None,
                  **kwargs):
    mode = modelscope_dict['mode']
    args, ASRTask = parse_args(mode=mode)
@@ -144,8 +144,9 @@
        args.patience = None
    args.local_rank = local_rank
    args.distributed = distributed
    for key, value in kwargs.items():
        args.key = value
    if meta_dict is not None:
        for key, value in meta_dict.items():
            args.key = value
    ASRTask.finetune_args = args
    return ASRTask
funasr/bin/train.py
@@ -489,7 +489,7 @@
        "--lora_bias",
        type=str,
        default="none",
        help="oss bucket.",
        help="lora bias.",
    )
    return parser
funasr/tasks/abs_task.py
@@ -71,6 +71,7 @@
from funasr.utils.types import str_or_none
from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
from funasr.modules.lora.utils import mark_only_lora_as_trainable
try:
    import wandb
@@ -952,6 +953,18 @@
            default=None,
            help="oss bucket.",
        )
        group.add_argument(
            "--enable_lora",
            type=str2bool,
            default=False,
            help="Apply lora for finetuning.",
        )
        group.add_argument(
            "--lora_bias",
            type=str,
            default="none",
            help="lora bias.",
        )
        cls.trainer.add_arguments(parser)
        cls.add_task_arguments(parser)
@@ -1246,6 +1259,8 @@
            dtype=getattr(torch, args.train_dtype),
            device="cuda" if args.ngpu > 0 else "cpu",
        )
        if args.enable_lora:
            mark_only_lora_as_trainable(model, args.lora_bias)
        for t in args.freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t: