From 6ca0b838d48106030984eacf204e8f1f2f05985b Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 13 六月 2024 16:07:49 +0800
Subject: [PATCH] decoding

---
 funasr/models/llm_asr/model.py |    8 ++++----
 1 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 15969e3..85351b7 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -410,17 +410,17 @@
             audio_encoder_output_size = audio_encoder.output_size()
         freeze = audio_encoder_conf.get("freeze", True)
         freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1))
-        if freeze_layer_num > 0:
-            freeze_layer_num = range(freeze_layer_num)
+        # if freeze_layer_num > 0:
+        #     freeze_layer_num = range(freeze_layer_num)
 
         if freeze:
             for name, param in audio_encoder.named_parameters():
-                if isinstance(freeze_layer_num, (list, tuple)):
+                if freeze_layer_num > 0:
                     idx = re.search(r"\.\d+\.", name)
                     if idx is not None:
                         beg, end = idx.regs[0]
                         layer_id = int(name[beg + 1 : end - 1])
-                        if layer_id in freeze_layer_num:
+                        if layer_id < freeze_layer_num:
                             param.requires_grad = False
                     else:
                         param.requires_grad = False

--
Gitblit v1.9.1