游雁
2024-06-11 23008c7cac4be72d99f2172660c3975bbc54a5ea
fp16
1个文件已修改
6 ■■■■■ 已修改文件
funasr/models/llm_asr/model.py 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/llm_asr/model.py
@@ -706,7 +706,13 @@
                batch_idx, :min_len, :
            ]
        llm_dtype = kwargs.get("llm_dtype", "fp32")
        dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
        with torch.cuda.amp.autocast(dtype=dtype_map[llm_dtype]):
        label = contents["assistant"][0]
            self.llm = self.llm.to(dtype_map[llm_dtype])
            inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
            attention_mask = attention_mask.to(dtype_map[llm_dtype])
        if not kwargs.get("tearchforing", False):
            generated_ids = self.llm.generate(