From 9afcf0ea7d2877ddbbafec5b1a77f5cf025dab17 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 12 六月 2024 17:17:03 +0800
Subject: [PATCH] decoding

---
 funasr/models/llm_asr/model.py |    8 +++++---
 1 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index f5d6a82..2a55cd6 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -410,15 +410,17 @@
         freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1))
         if freeze_layer_num > 0:
             freeze_layer_num = range(freeze_layer_num)
-        else:
-            freeze_layer_num = [freeze_layer_num]
+
         if freeze:
             for name, param in audio_encoder.named_parameters():
                 idx = re.search(r"\.\d+\.", name)
                 if idx is not None:
                     beg, end = idx.regs[0]
                     layer_id = int(name[beg + 1 : end - 1])
-                    if layer_id in freeze_layer_num:
+                    if isinstance(freeze_layer_num, (list, tuple)):
+                        if layer_id in freeze_layer_num:
+                            param.requires_grad = False
+                    else:
                         param.requires_grad = False
             audio_encoder.eval()
 

--
Gitblit v1.9.1