From 7d57828086d8c83707478bf45a29b68482962e6e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 12 六月 2024 14:29:28 +0800
Subject: [PATCH] decoding

---
 funasr/models/llm_asr/model.py |   12 +++++++++++-
 1 files changed, 11 insertions(+), 1 deletions(-)

diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index dd806cf..e7dd11e 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -407,9 +407,19 @@
             audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
             audio_encoder_output_size = audio_encoder.output_size()
         freeze = audio_encoder_conf.get("freeze", True)
+        freeze_layer_num = audio_encoder_conf.get("freeze_layer_num", -1)
+        if freeze_layer_num > 0:
+            freeze_layer_num = range(freeze_layer_num)
+        else:
+            freeze_layer_num = [freeze_layer_num]
         if freeze:
             for name, param in audio_encoder.named_parameters():
-                param.requires_grad = False
+                idx = re.search(r"\.\d+\.", name)
+                if idx is not None:
+                    beg, end = idx.regs[0]
+                    layer_id = int(name[beg + 1 : end - 1])
+                    if layer_id in freeze_layer_num:
+                        param.requires_grad = False
             audio_encoder.eval()
 
         self.audio_encoder = audio_encoder

--
Gitblit v1.9.1