| funasr/models/llm_asr/model.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/models/llm_asr/model.py
@@ -715,7 +715,9 @@ llm_dtype = kwargs.get("llm_dtype", "fp32") dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32} with torch.cuda.amp.autocast(dtype=dtype_map[llm_dtype]): with torch.cuda.amp.autocast( enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype] ): label = contents["assistant"][0] # self.llm = self.llm.to(dtype_map[llm_dtype]) # inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])