import traceback

import torch
import torch.nn as nn

from funasr.register import tables
from funasr.train_utils.device_funcs import to_device

# Map config-level dtype strings onto torch dtypes for autocast.
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
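
# A minimal illustrative sketch (not part of FunASR's API) of the precision
# pattern used throughout this file: autocast is enabled only when the
# configured dtype is not fp32, so fp32 runs entirely outside autocast.
def _autocast_for(llm_dtype: str):
    # Hypothetical helper for illustration; the methods below inline this
    # logic. Usage: `with _autocast_for("bf16"): ...`
    return torch.cuda.amp.autocast(
        enabled=llm_dtype != "fp32", dtype=dtype_map[llm_dtype]
    )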


@tables.register("model_classes", "LLMASR")
class LLMASR(nn.Module):
    """LLM-based ASR: an audio adaptor maps speech-encoder features into the
    embedding space of a causal LLM, which produces the transcription."""

        # ... (__init__ signature and audio-encoder setup elided in this excerpt)

        # The pre-trained LLM is kept in eval mode (disables dropout etc.).
        model.eval()
        self.llm = model
        # Width of the LLM's token embeddings; the adaptor must project to it.
        llm_dim = model.get_input_embeddings().weight.shape[-1]
        self.llm_dtype = llm_conf.get("llm_dtype", "fp32")

        # Audio adaptor: resolve the configured name to a registered class.
        adaptor_class = tables.adaptor_classes.get(audio_adaptor)
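        # Instantiation is elided in this excerpt; by the same registry
        # pattern as the @tables.register decorator above, a typical
        # follow-up would be something like adaptor_class(**audio_adaptor_conf)
        # projecting from the encoder width to llm_dim (names assumed, not
        # taken from this excerpt).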

        # ... (forward() context elided: a padded batch tensor is sliced per
        # sample to a common length min_len along the time axis)
            batch_idx, :min_len, :
        ]

        # Run the LLM forward under autocast unless full fp32 is configured.
        with torch.cuda.amp.autocast(
            enabled=self.llm_dtype != "fp32", dtype=dtype_map[self.llm_dtype]
        ):
            # -100 is the ignore index of the (HuggingFace) LM loss, so
            # positions padded with -1 are excluded from the loss.
            labels_ids[labels_ids == -1] = -100
            # Padding positions encoded as negative values become plain 0s.
            attention_mask[attention_mask < 0] = 0
            model_outputs = self.llm(
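                # (call arguments elided in this excerpt; a HuggingFace causal
                # LM forward typically takes inputs_embeds, attention_mask and
                # labels at this point, which is an assumption, not shown above)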

        # Optional per-call precision override at inference time.
        llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
        llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype

        # Reuse the module-level dtype_map; same autocast pattern as training.
        with torch.cuda.amp.autocast(
            enabled=llm_dtype != "fp32", dtype=dtype_map[llm_dtype]
        ):
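            # Decoding body elided in this excerpt; inside this context the
            # LLM would run its forward/generate in the selected precision,
            # e.g. after a caller passes bf16=True to the surrounding
            # inference method (usage assumed, not shown above).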