self.audio_encoder = audio_encoder

# llm
hub = llm_conf.get("hub", "hf")
self.llm = None
if hub == "hf":
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

    init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
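    # Load the pretrained causal LM from the configured path (or HF hub id)
    # and freeze it, so only the audio adaptor below receives gradient updates.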
    model = AutoModelForCausalLM.from_pretrained(init_param_path)
    for param in model.parameters():
        param.requires_grad = False
    model.eval()
    self.llm = model
    # hidden size of the LLM; the adaptor must project encoder features to this dimension
    llm_dim = model.get_input_embeddings().weight.shape[-1]

# adaptor
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
audio_adaptor_conf["llm_dim"] = llm_dim
audio_adaptor = adaptor_class(**audio_adaptor_conf)
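# Optionally warm-start the adaptor from a checkpoint; strict=False tolerates
# missing/unexpected keys, and the returned load status is logged for inspection.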
init_param_path = audio_adaptor_conf.get("init_param_path", None)
if init_param_path is not None:
    src_state = torch.load(init_param_path, map_location="cpu")
    flag = audio_adaptor.load_state_dict(src_state, strict=False)
    logging.info(f"Loading audio_adaptor ckpt: {init_param_path}, status: {flag}")

self.audio_adaptor = audio_adaptor
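# A minimal sketch of how these pieces compose at run time (assumption: names
# like `speech` and `text_embeds` are illustrative, not defined in this file):
#   feats = self.audio_encoder(speech)            # (B, T, encoder_dim)
#   feats = self.audio_adaptor(feats)             # (B, T', llm_dim)
#   inputs_embeds = torch.cat([text_embeds, feats], dim=1)
#   output = self.llm(inputs_embeds=inputs_embeds)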