游雁
2024-06-08 2191795f742063b1c0a394fc2a65898445ccce65
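fix bug: in the teacher-forcing branch, decode the full prediction span `[:, source_ids.shape[1] :]` instead of the single position `[:, source_ids.shape[1]]`, and report the loss as a Python scalar via `.item()`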
1个文件已修改
7 ■■■■■ 已修改文件
funasr/models/llm_asr/model.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/llm_asr/model.py
@@ -692,6 +692,7 @@
                batch_idx, :min_len, :
            ]
        label = contents["assistant"][0]
        if not kwargs.get("tearchforing", False):
            generated_ids = self.llm.generate(
@@ -704,7 +705,7 @@
            response = tokenizer.batch_decode(
                generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
            )[0]
            label = contents["assistant"][0]
            loss = None
        else:
@@ -715,13 +716,13 @@
                inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
            )
            preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]]
            preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
            response = tokenizer.batch_decode(
                preds,
                add_special_tokens=False,
                skip_special_tokens=kwargs.get("skip_special_tokens", True),
            )[0]
            loss = model_outputs.loss
            loss = model_outputs.loss.item()
        ibest_writer = None
        if kwargs.get("output_dir") is not None: