游雁
2024-06-12 2ac79cd3f312e485f3fc4f0e63313cc8a3e0bfc6
funasr/models/llm_asr/model.py
@@ -413,44 +413,51 @@
        if freeze:
            for name, param in audio_encoder.named_parameters():
                idx = re.search(r"\.\d+\.", name)
                if idx is not None:
                    beg, end = idx.regs[0]
                    layer_id = int(name[beg + 1 : end - 1])
                    if isinstance(freeze_layer_num, (list, tuple)):
                if isinstance(freeze_layer_num, (list, tuple)):
                    idx = re.search(r"\.\d+\.", name)
                    if idx is not None:
                        beg, end = idx.regs[0]
                        layer_id = int(name[beg + 1 : end - 1])
                        if layer_id in freeze_layer_num:
                            param.requires_grad = False
                    else:
                        param.requires_grad = False
                else:
                    param.requires_grad = False
            audio_encoder.eval()
        self.audio_encoder = audio_encoder
        # llm
        hub = llm_conf.get("hub", "hf")
        self.llm = None
        if hub == "hf":
            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
            init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
        from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
            model = AutoModelForCausalLM.from_pretrained(
                init_param_path,
                load_in_8bit=None,
                device_map=None,
                use_cache=None,
            )
            freeze = llm_conf.get("freeze", True)
            if freeze:
                for name, param in model.named_parameters():
                    param.requires_grad = False
                model.eval()
            self.llm = model
        init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
        model = AutoModelForCausalLM.from_pretrained(
            init_param_path,
            load_in_8bit=None,
            device_map=None,
            use_cache=None,
        )
        freeze = llm_conf.get("freeze", True)
        if freeze:
            for name, param in model.named_parameters():
                param.requires_grad = False
            model.eval()
        self.llm = model
        llm_dim = model.get_input_embeddings().weight.shape[-1]
        # adaptor
        adaptor_class = tables.adaptor_classes.get(audio_adaptor)
        audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
        audio_adaptor_conf["llm_dim"] = llm_dim
        audio_adaptor = adaptor_class(**audio_adaptor_conf)
        init_param_path = audio_adaptor_conf.get("init_param_path", None)
        if init_param_path is not None:
            src_state = torch.load(init_param_path, map_location="cpu")
            flag = audio_adaptor.load_state_dict(src_state, strict=False)
            logging.info(f"Loading audio_adaptor ckpt: {init_param_path}, status: {flag}")
        self.audio_adaptor = audio_adaptor