From 3d5e19792cd4bb510c2c0fc5749731d52b825c15 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Sat, 08 Jun 2024 18:43:35 +0800
Subject: [PATCH] llm_asr: fix teacher-forcing decode path, batched input wrapping, and writer output
---
funasr/models/llm_asr/model.py | 28 ++++++++++++++++++++++++----
examples/industrial_data_pretraining/llm_asr/demo_speech2text.py | 4 +++-
2 files changed, 27 insertions(+), 5 deletions(-)
diff --git a/examples/industrial_data_pretraining/llm_asr/demo_speech2text.py b/examples/industrial_data_pretraining/llm_asr/demo_speech2text.py
index ed02373..072dcdf 100644
--- a/examples/industrial_data_pretraining/llm_asr/demo_speech2text.py
+++ b/examples/industrial_data_pretraining/llm_asr/demo_speech2text.py
@@ -16,12 +16,14 @@
with open(jsonl, "r") as f:
lines = f.readlines()
+tearchforing = True
for i, line in enumerate(lines):
data_dict = json.loads(line.strip())
data = data_dict["messages"]
res = model.generate(
- input=data,
+ input=[data],
+ tearchforing=tearchforing,
cache={},
)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 0955e84..697f78d 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -568,6 +568,7 @@
[],
[],
[],
+ [],
)
for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
@@ -624,7 +625,7 @@
input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]
- source_ids = torch.tensor(source_ids, dtype=torch.int64)
+ source_ids = torch.tensor(source_ids_i, dtype=torch.int64)
target_ids = torch.tensor(target_ids, dtype=torch.int64)
fbank = speech[0, :, :]
@@ -662,7 +663,7 @@
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
- contents = self.data_template(data_in)
+ contents = self.data_template(data_in[0])
output = self.data_load_speech(contents, tokenizer, frontend, **kwargs)
batch = to_device(output, kwargs["device"])
@@ -676,7 +677,7 @@
input_ids = batch["input_ids"]
source_ids = batch["source_ids"]
- if kwargs.get("tearchforing", False):
+ if not kwargs.get("tearchforing", False):
input_ids = source_ids
input_ids[input_ids < 0] = 0
inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
@@ -704,6 +705,23 @@
generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
)[0]
label = contents["assistant"][0]
+ loss = None
+ else:
+
+ labels_ids = batch["labels_ids"]
+ labels_ids[labels_ids == -1] = -100
+ attention_mask = batch.get("attention_mask", None)
+ model_outputs = self.llm(
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
+ )
+
+ preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]]
+ response = tokenizer.batch_decode(
+ preds,
+ add_special_tokens=False,
+ skip_special_tokens=kwargs.get("skip_special_tokens", True),
+ )[0]
+ loss = model_outputs.loss
ibest_writer = None
if kwargs.get("output_dir") is not None:
@@ -713,10 +731,12 @@
results = []
result_i = {"key": key[0], "text": response, "label": label}
+ if loss is not None:
+ result_i["loss"] = loss
results.append(result_i)
if ibest_writer is not None:
- ibest_writer["text"][key[0]] = text
+ ibest_writer["text"][key[0]] = response
ibest_writer["label"][key[0]] = label
return results, meta_data
--
Gitblit v1.9.1