| | |
| | | ] |
| | | |
| | | llm_dtype = kwargs.get("llm_dtype", "fp32") |
| | | if llm_dtype == "fp32": |
| | | llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype |
| | | llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype |
| | | |
| | | dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32} |
| | | with torch.cuda.amp.autocast( |
| | | enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype] |
| | | ): |
| | | label = contents["assistant"][0] |
| | | # self.llm = self.llm.to(dtype_map[llm_dtype]) |
| | | # inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype]) |
| | | self.llm = self.llm.to(dtype_map[llm_dtype]) |
| | | inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype]) |
| | | |
| | | if not kwargs.get("tearchforing", False): |
| | | |