The forward path casts the LLM and its inputs to the requested dtype, then branches between free-running generation and a teacher-forced loss pass:

```python
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
with torch.cuda.amp.autocast(dtype=dtype_map[llm_dtype]):
    # Ground-truth assistant response for this sample.
    label = contents["assistant"][0]
    # Cast the model and its inputs to a single dtype so they agree
    # inside the autocast region.
    self.llm = self.llm.to(dtype_map[llm_dtype])
    inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
    attention_mask = attention_mask.to(dtype_map[llm_dtype])
    if not kwargs.get("teacher_forcing", False):
        # Free-running decoding from the prepared embeddings; decoding
        # options (max_new_tokens, sampling flags, ...) are passed here.
        generated_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
    else:
        # Teacher forcing: score the gold labels and return the LM loss.
        labels_ids = batch["labels_ids"]
        # Remap the dataset's -1 padding sentinel to -100, the ignore
        # index Hugging Face models use so padded positions are
        # excluded from the loss.
        labels_ids[labels_ids == -1] = -100
        attention_mask = batch.get("attention_mask", None)
        model_outputs = self.llm(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
        )
```
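The `-1` to `-100` remap is what keeps padded positions out of the loss: `-100` is the default `ignore_index` of `torch.nn.functional.cross_entropy`, which Hugging Face causal LMs use internally when `labels` are supplied. A minimal, self-contained illustration with toy tensors:

```python
import torch
import torch.nn.functional as F

# Toy logits: batch of 1, sequence of 4 positions, vocabulary of 5.
logits = torch.randn(1, 4, 5)
labels = torch.tensor([[2, 3, -100, -100]])  # last two positions are padding

# Positions labeled -100 are skipped entirely, so the loss averages
# over the two real tokens only (ignore_index defaults to -100).
loss = F.cross_entropy(logits.view(-1, 5), labels.view(-1))
print(loss)
```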
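Generating from `inputs_embeds` rather than `input_ids` is supported by Hugging Face `generate`. A sketch of the pattern, using `gpt2` purely as a stand-in for the wrapped LLM:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)

enc = tok("Describe the image:", return_tensors="pt")
# Embed the prompt ourselves, mirroring how precomputed embeddings are
# fed to self.llm.generate above.
inputs_embeds = model.get_input_embeddings()(enc.input_ids)

out_ids = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=enc.attention_mask,
    max_new_tokens=16,
    pad_token_id=tok.eos_token_id,
)
# When only inputs_embeds is given, the returned ids hold just the
# newly generated tokens, not the prompt.
print(tok.decode(out_ids[0], skip_special_tokens=True))
```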