kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/lora/utils.py
@@ -10,41 +10,38 @@
from .layers import LoRALayer
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    for n, p in model.named_parameters():
        if 'lora_' not in n and 'cif' not in n:
        if "lora_" not in n and "cif" not in n:
            p.requires_grad = False
    if bias == 'none':
    if bias == "none":
        return
    elif bias == 'all':
    elif bias == "all":
        for n, p in model.named_parameters():
            if 'bias' in n:
            if "bias" in n:
                p.requires_grad = True
    elif bias == 'lora_only':
    elif bias == "lora_only":
        for m in model.modules():
            if isinstance(m, LoRALayer) and \
                hasattr(m, 'bias') and \
                m.bias is not None:
                    m.bias.requires_grad = True
            if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
                m.bias.requires_grad = True
    else:
        raise NotImplementedError
def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
def lora_state_dict(model: nn.Module, bias: str = "none") -> Dict[str, torch.Tensor]:
    my_state_dict = model.state_dict()
    if bias == 'none':
        return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
    elif bias == 'all':
        return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
    elif bias == 'lora_only':
    if bias == "none":
        return {k: my_state_dict[k] for k in my_state_dict if "lora_" in k}
    elif bias == "all":
        return {k: my_state_dict[k] for k in my_state_dict if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        for k in my_state_dict:
            if 'lora_' in k:
            if "lora_" in k:
                to_return[k] = my_state_dict[k]
                bias_name = k.split('lora_')[0]+'bias'
                bias_name = k.split("lora_")[0] + "bias"
                if bias_name in my_state_dict:
                    to_return[bias_name] = my_state_dict[bias_name]
        return to_return
    else:
        raise NotImplementedError