From 2191795f742063b1c0a394fc2a65898445ccce65 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期六, 08 六月 2024 19:45:15 +0800
Subject: [PATCH] fix bug
---
funasr/models/llm_asr/model.py | 7 ++++---
1 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 697f78d..f8c3efc 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -692,6 +692,7 @@
batch_idx, :min_len, :
]
+ label = contents["assistant"][0]
if not kwargs.get("tearchforing", False):
generated_ids = self.llm.generate(
@@ -704,7 +705,7 @@
response = tokenizer.batch_decode(
generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
)[0]
- label = contents["assistant"][0]
+
loss = None
else:
@@ -715,13 +716,13 @@
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
)
- preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]]
+ preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
response = tokenizer.batch_decode(
preds,
add_special_tokens=False,
skip_special_tokens=kwargs.get("skip_special_tokens", True),
)[0]
- loss = model_outputs.loss
+ loss = model_outputs.loss.item()
ibest_writer = None
if kwargs.get("output_dir") is not None:
--
Gitblit v1.9.1