游雁
2024-06-11 23008c7cac4be72d99f2172660c3975bbc54a5ea
fp16
1 file changed, 60 lines changed
funasr/models/llm_asr/model.py | 60
funasr/models/llm_asr/model.py
@@ -706,37 +706,43 @@
                 batch_idx, :min_len, :
             ]
-        label = contents["assistant"][0]
-        if not kwargs.get("tearchforing", False):
+        llm_dtype = kwargs.get("llm_dtype", "fp32")
+        dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
+        with torch.cuda.amp.autocast(dtype=dtype_map[llm_dtype]):
+            label = contents["assistant"][0]
+            self.llm = self.llm.to(dtype_map[llm_dtype])
+            inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
+            attention_mask = attention_mask.to(dtype_map[llm_dtype])
+            if not kwargs.get("tearchforing", False):
-            generated_ids = self.llm.generate(
-                inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512)
-            )
-            # generated_ids = [
-            #     output_ids[len(input_id) :]
-            #     for input_id, output_ids in zip(input_ids, generated_ids)
-            # ]
-            response = tokenizer.batch_decode(
-                generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
-            )[0]
+                generated_ids = self.llm.generate(
+                    inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512)
+                )
+                # generated_ids = [
+                #     output_ids[len(input_id) :]
+                #     for input_id, output_ids in zip(input_ids, generated_ids)
+                # ]
+                response = tokenizer.batch_decode(
+                    generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
+                )[0]
-            loss = None
-        else:
+                loss = None
+            else:
-            labels_ids = batch["labels_ids"]
-            labels_ids[labels_ids == -1] = -100
-            attention_mask = batch.get("attention_mask", None)
-            model_outputs = self.llm(
-                inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
-            )
+                labels_ids = batch["labels_ids"]
+                labels_ids[labels_ids == -1] = -100
+                attention_mask = batch.get("attention_mask", None)
+                model_outputs = self.llm(
+                    inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
+                )
-            preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
-            response = tokenizer.batch_decode(
-                preds,
-                add_special_tokens=False,
-                skip_special_tokens=kwargs.get("skip_special_tokens", True),
-            )[0]
-            loss = model_outputs.loss.item()
+                preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
+                response = tokenizer.batch_decode(
+                    preds,
+                    add_special_tokens=False,
+                    skip_special_tokens=kwargs.get("skip_special_tokens", True),
+                )[0]
+                loss = model_outputs.loss.item()
         ibest_writer = None
         if kwargs.get("output_dir") is not None:
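For context, the pattern this commit introduces is the standard PyTorch reduced-precision inference recipe: map a string flag to a torch dtype, cast the model and its floating-point inputs once, then run the forward/generate call under torch.cuda.amp.autocast. Below is a minimal, self-contained sketch of the same idea, assuming a Hugging Face-style causal LM; the function and variable names are illustrative stand-ins, not FunASR's actual API. One deliberate difference: the sketch leaves attention_mask in its original integer/bool dtype, which is what such models generally expect, whereas the patch casts the mask to the float dtype as well.

import torch

DTYPE_MAP = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}

def generate_reduced_precision(llm, inputs_embeds, attention_mask,
                               llm_dtype="fp32", max_new_tokens=512):
    # Map the string flag to a torch dtype, as the patch does.
    dtype = DTYPE_MAP[llm_dtype]
    # Cast the model weights and the floating-point inputs once.
    llm = llm.to(dtype)
    inputs_embeds = inputs_embeds.to(dtype)
    # Run generation under autocast so intermediate ops also use the
    # reduced precision. (attention_mask stays integer/bool here.)
    with torch.cuda.amp.autocast(dtype=dtype):
        return llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
        )

Casting the weights with .to(dtype) halves memory for fp16/bf16, while the autocast context additionally keeps numerically sensitive ops (e.g. softmax, layer norm) in higher precision where PyTorch's autocast policy dictates.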