From 2191795f742063b1c0a394fc2a65898445ccce65 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Sat, 08 Jun 2024 19:45:15 +0800
Subject: [PATCH] fix teacher-forcing inference: decode full predicted span, hoist label, detach loss

---
 funasr/models/llm_asr/model.py |    7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 697f78d..f8c3efc 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -692,6 +692,7 @@
                 batch_idx, :min_len, :
             ]
 
+        label = contents["assistant"][0]
         if not kwargs.get("tearchforing", False):
 
             generated_ids = self.llm.generate(
@@ -704,7 +705,7 @@
             response = tokenizer.batch_decode(
                 generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
             )[0]
-            label = contents["assistant"][0]
+
             loss = None
         else:
 
@@ -715,13 +716,13 @@
                 inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
             )
 
-            preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]]
+            preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
             response = tokenizer.batch_decode(
                 preds,
                 add_special_tokens=False,
                 skip_special_tokens=kwargs.get("skip_special_tokens", True),
             )[0]
-            loss = model_outputs.loss
+            loss = model_outputs.loss.item()
 
         ibest_writer = None
         if kwargs.get("output_dir") is not None:

--
Gitblit v1.9.1