haoneng.lhn
2023-07-18 16c41542451f399bdb716f1d7cad31cf52f6f8c3
funasr/bin/train.py
@@ -28,6 +28,7 @@
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
from funasr.modules.lora.utils import mark_only_lora_as_trainable
def get_parser():
@@ -478,6 +479,18 @@
        default=None,
        help="oss bucket.",
    )
    parser.add_argument(
        "--enable_lora",
        type=str2bool,
        default=False,
        help="Apply lora for finetuning.",
    )
    parser.add_argument(
        "--lora_bias",
        type=str,
        default="none",
        help="oss bucket.",
    )
    return parser
@@ -521,6 +534,8 @@
        dtype=getattr(torch, args.train_dtype),
        device="cuda" if args.ngpu > 0 else "cpu",
    )
    if args.enable_lora:
        mark_only_lora_as_trainable(model, args.lora_bias)
    for t in args.freeze_param:
        for k, p in model.named_parameters():
            if k.startswith(t + ".") or k == t: