游雁
2024-06-11 779033386a3680fa4f6236850bc97f135494dcf6
fixbug
1个文件已修改
4 ■■■ 已修改文件
funasr/models/llm_asr/model.py 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/llm_asr/model.py
@@ -715,7 +715,9 @@
        llm_dtype = kwargs.get("llm_dtype", "fp32")
        dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
        with torch.cuda.amp.autocast(dtype=dtype_map[llm_dtype]):
        with torch.cuda.amp.autocast(
            enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
        ):
            label = contents["assistant"][0]
            # self.llm = self.llm.to(dtype_map[llm_dtype])
            # inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])