游雁
2024-06-12 e4a69d4768674e57faf4a08eecca2fce88d3e190
funasr/models/llm_asr/model.py
@@ -413,15 +413,16 @@
        if freeze:
            for name, param in audio_encoder.named_parameters():
                idx = re.search(r"\.\d+\.", name)
                if idx is not None:
                    beg, end = idx.regs[0]
                    layer_id = int(name[beg + 1 : end - 1])
                    if isinstance(freeze_layer_num, (list, tuple)):
                if isinstance(freeze_layer_num, (list, tuple)):
                    idx = re.search(r"\.\d+\.", name)
                    if idx is not None:
                        beg, end = idx.regs[0]
                        layer_id = int(name[beg + 1 : end - 1])
                        if layer_id in freeze_layer_num:
                            param.requires_grad = False
                    else:
                        param.requires_grad = False
                else:
                    param.requires_grad = False
            audio_encoder.eval()
        self.audio_encoder = audio_encoder