游雁
2024-06-11 9e13f028bc4c0442b41801cfc346b4465f10d578
fixbug
1个文件已修改
8 ■■■■ 已修改文件
funasr/models/llm_asr/model.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/llm_asr/model.py
@@ -714,13 +714,17 @@
            ]
        llm_dtype = kwargs.get("llm_dtype", "fp32")
        if llm_dtype == "fp32":
            llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
            llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
        dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
        with torch.cuda.amp.autocast(
            enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
        ):
            label = contents["assistant"][0]
            # self.llm = self.llm.to(dtype_map[llm_dtype])
            # inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
            self.llm = self.llm.to(dtype_map[llm_dtype])
            inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
            if not kwargs.get("tearchforing", False):