From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/models/lora/utils.py | 35 ++++++++++++++++-------------------
1 files changed, 16 insertions(+), 19 deletions(-)
diff --git a/funasr/models/lora/utils.py b/funasr/models/lora/utils.py
index e18bf44..670c1dc 100644
--- a/funasr/models/lora/utils.py
+++ b/funasr/models/lora/utils.py
@@ -10,41 +10,38 @@
from .layers import LoRALayer
-def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
+def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
for n, p in model.named_parameters():
- if 'lora_' not in n and 'cif' not in n:
+ if "lora_" not in n and "cif" not in n:
p.requires_grad = False
- if bias == 'none':
+ if bias == "none":
return
- elif bias == 'all':
+ elif bias == "all":
for n, p in model.named_parameters():
- if 'bias' in n:
+ if "bias" in n:
p.requires_grad = True
- elif bias == 'lora_only':
+ elif bias == "lora_only":
for m in model.modules():
- if isinstance(m, LoRALayer) and \
- hasattr(m, 'bias') and \
- m.bias is not None:
- m.bias.requires_grad = True
+ if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
+ m.bias.requires_grad = True
else:
raise NotImplementedError
-def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
+def lora_state_dict(model: nn.Module, bias: str = "none") -> Dict[str, torch.Tensor]:
my_state_dict = model.state_dict()
- if bias == 'none':
- return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
- elif bias == 'all':
- return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
- elif bias == 'lora_only':
+ if bias == "none":
+ return {k: my_state_dict[k] for k in my_state_dict if "lora_" in k}
+ elif bias == "all":
+ return {k: my_state_dict[k] for k in my_state_dict if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
to_return = {}
for k in my_state_dict:
- if 'lora_' in k:
+ if "lora_" in k:
to_return[k] = my_state_dict[k]
- bias_name = k.split('lora_')[0]+'bias'
+ bias_name = k.split("lora_")[0] + "bias"
if bias_name in my_state_dict:
to_return[bias_name] = my_state_dict[bias_name]
return to_return
else:
raise NotImplementedError
-
--
Gitblit v1.9.1