haoneng.lhn
2023-07-19 7a19c526024ce68b0a35f37b30b12037dba82240
funasr/tasks/abs_task.py
@@ -71,6 +71,7 @@
from funasr.utils.types import str_or_none
from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
from funasr.modules.lora.utils import mark_only_lora_as_trainable
try:
    import wandb
@@ -952,6 +953,18 @@
            default=None,
            help="oss bucket.",
        )
        group.add_argument(
            "--enable_lora",
            type=str2bool,
            default=False,
            help="Apply lora for finetuning.",
        )
        group.add_argument(
            "--lora_bias",
            type=str,
            default="none",
            help="lora bias.",
        )
        cls.trainer.add_arguments(parser)
        cls.add_task_arguments(parser)
@@ -1246,6 +1259,8 @@
            dtype=getattr(torch, args.train_dtype),
            device="cuda" if args.ngpu > 0 else "cpu",
        )
        if args.enable_lora:
            mark_only_lora_as_trainable(model, args.lora_bias)
        for t in args.freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t: