| funasr/models/llm_asr/model.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/models/llm_asr/model.py
@@ -706,7 +706,13 @@ batch_idx, :min_len, : ] llm_dtype = kwargs.get("llm_dtype", "fp32") dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32} with torch.cuda.amp.autocast(dtype=dtype_map[llm_dtype]): label = contents["assistant"][0] self.llm = self.llm.to(dtype_map[llm_dtype]) inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype]) attention_mask = attention_mask.to(dtype_map[llm_dtype]) if not kwargs.get("tearchforing", False): generated_ids = self.llm.generate(