From 9afcf0ea7d2877ddbbafec5b1a77f5cf025dab17 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 12 六月 2024 17:17:03 +0800
Subject: [PATCH] decoding
---
funasr/models/llm_asr/model.py | 8 +++++---
1 files changed, 5 insertions(+), 3 deletions(-)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index f5d6a82..2a55cd6 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -410,15 +410,17 @@
freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1))
if freeze_layer_num > 0:
freeze_layer_num = range(freeze_layer_num)
- else:
- freeze_layer_num = [freeze_layer_num]
+
if freeze:
for name, param in audio_encoder.named_parameters():
idx = re.search(r"\.\d+\.", name)
if idx is not None:
beg, end = idx.regs[0]
layer_id = int(name[beg + 1 : end - 1])
- if layer_id in freeze_layer_num:
+ if isinstance(freeze_layer_num, (list, tuple)):
+ if layer_id in freeze_layer_num:
+ param.requires_grad = False
+ else:
param.requires_grad = False
audio_encoder.eval()
--
Gitblit v1.9.1