From 2a8d041806df41fa3719505d1b3379bbbd369574 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期六, 08 六月 2024 21:35:21 +0800
Subject: [PATCH] fix bug

---
 funasr/models/llm_asr/model.py |   37 ++++++++++++++++++++++++++++++-------
 1 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 0955e84..5fde3ff 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -19,6 +19,7 @@
 from funasr.utils.datadir_writer import DatadirWriter
 from funasr.register import tables
 from funasr.train_utils.device_funcs import to_device
+import traceback
 
 
 @tables.register("model_classes", "LLMASR")
@@ -489,6 +490,7 @@
             fbank_fake_len = fbank_fake_lens[batch_idx].item()
             fbank_beg_idx = fbank_beg[batch_idx, 0].item()
             min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx)
+
             try:
                 inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
                     batch_idx, :min_len, :
@@ -496,10 +498,10 @@
             except Exception as e:
                 logging.error(f"{str(e)}, {traceback.format_exc()}")
                 logging.info(
-                    f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, min_len: {min_len}, fbank_fake_len: {fbank_fake_len}"
+                    f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, min_len: {min_len}, fbank_fake_len: {fbank_fake_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens[batch_idx].item()}"
                 )
                 fbank_fake_len = encoder_out_lens[batch_idx].item()
-                min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx)
+                min_len = min(fbank_fake_len, min_len)
                 inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
                     batch_idx, :min_len, :
                 ]
@@ -568,6 +570,7 @@
             [],
             [],
             [],
+            [],
         )
 
         for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
@@ -624,7 +627,7 @@
         input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
         attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
         labels = torch.tensor(labels, dtype=torch.int64)  # [: self.max_token_length]
-        source_ids = torch.tensor(source_ids, dtype=torch.int64)
+        source_ids = torch.tensor(source_ids_i, dtype=torch.int64)
         target_ids = torch.tensor(target_ids, dtype=torch.int64)
 
         fbank = speech[0, :, :]
@@ -662,7 +665,7 @@
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
 
-        contents = self.data_template(data_in)
+        contents = self.data_template(data_in[0])
         output = self.data_load_speech(contents, tokenizer, frontend, **kwargs)
         batch = to_device(output, kwargs["device"])
 
@@ -676,7 +679,7 @@
 
         input_ids = batch["input_ids"]
         source_ids = batch["source_ids"]
-        if kwargs.get("tearchforing", False):
+        if not kwargs.get("tearchforing", False):
             input_ids = source_ids
         input_ids[input_ids < 0] = 0
         inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
@@ -691,6 +694,7 @@
                 batch_idx, :min_len, :
             ]
 
+        label = contents["assistant"][0]
         if not kwargs.get("tearchforing", False):
 
             generated_ids = self.llm.generate(
@@ -703,7 +707,24 @@
             response = tokenizer.batch_decode(
                 generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
             )[0]
-            label = contents["assistant"][0]
+
+            loss = None
+        else:
+
+            labels_ids = batch["labels_ids"]
+            labels_ids[labels_ids == -1] = -100
+            attention_mask = batch.get("attention_mask", None)
+            model_outputs = self.llm(
+                inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
+            )
+
+            preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
+            response = tokenizer.batch_decode(
+                preds,
+                add_special_tokens=False,
+                skip_special_tokens=kwargs.get("skip_special_tokens", True),
+            )[0]
+            loss = model_outputs.loss.item()
 
         ibest_writer = None
         if kwargs.get("output_dir") is not None:
@@ -713,10 +734,12 @@
 
         results = []
         result_i = {"key": key[0], "text": response, "label": label}
+        if loss is not None:
+            result_i["loss"] = loss
         results.append(result_i)
 
         if ibest_writer is not None:
-            ibest_writer["text"][key[0]] = text
+            ibest_writer["text"][key[0]] = response
             ibest_writer["label"][key[0]] = label
 
         return results, meta_data

--
Gitblit v1.9.1