From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/lora/utils.py |   35 ++++++++++++++++-------------------
 1 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/funasr/models/lora/utils.py b/funasr/models/lora/utils.py
index e18bf44..670c1dc 100644
--- a/funasr/models/lora/utils.py
+++ b/funasr/models/lora/utils.py
@@ -10,41 +10,38 @@
 from .layers import LoRALayer
 
 
-def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
+def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
     for n, p in model.named_parameters():
-        if 'lora_' not in n and 'cif' not in n:
+        if "lora_" not in n and "cif" not in n:
             p.requires_grad = False
-    if bias == 'none':
+    if bias == "none":
         return
-    elif bias == 'all':
+    elif bias == "all":
         for n, p in model.named_parameters():
-            if 'bias' in n:
+            if "bias" in n:
                 p.requires_grad = True
-    elif bias == 'lora_only':
+    elif bias == "lora_only":
         for m in model.modules():
-            if isinstance(m, LoRALayer) and \
-                hasattr(m, 'bias') and \
-                m.bias is not None:
-                    m.bias.requires_grad = True
+            if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
+                m.bias.requires_grad = True
     else:
         raise NotImplementedError
 
 
-def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
+def lora_state_dict(model: nn.Module, bias: str = "none") -> Dict[str, torch.Tensor]:
     my_state_dict = model.state_dict()
-    if bias == 'none':
-        return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
-    elif bias == 'all':
-        return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
-    elif bias == 'lora_only':
+    if bias == "none":
+        return {k: my_state_dict[k] for k in my_state_dict if "lora_" in k}
+    elif bias == "all":
+        return {k: my_state_dict[k] for k in my_state_dict if "lora_" in k or "bias" in k}
+    elif bias == "lora_only":
         to_return = {}
         for k in my_state_dict:
-            if 'lora_' in k:
+            if "lora_" in k:
                 to_return[k] = my_state_dict[k]
-                bias_name = k.split('lora_')[0]+'bias'
+                bias_name = k.split("lora_")[0] + "bias"
                 if bias_name in my_state_dict:
                     to_return[bias_name] = my_state_dict[bias_name]
         return to_return
     else:
         raise NotImplementedError
-

--
Gitblit v1.9.1