From d43f77408b8f3e169c59dfb6b6d82e45e6b91714 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 11 Jun 2024 19:19:06 +0800
Subject: [PATCH] decoding: add load/feature timing meta_data, fp16/bf16 inference, and normalized text_tn output
---
funasr/models/llm_asr/model.py | 92 +++++++++++++++++++++++++++++++---------------
1 file changed, 62 insertions(+), 30 deletions(-)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 5fde3ff..dd806cf 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -556,7 +556,7 @@
return contents
- def data_load_speech(self, contents: dict, tokenizer, frontend, **kwargs):
+ # note: meta_data={} is a mutable default argument; safe here only because the decode path below always passes its own dict
+ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
system = contents["system"]
user = contents["user"]
@@ -594,7 +594,10 @@
)
if sub_str.startswith("!"):
try:
+ time1 = time.perf_counter()
data_src = load_audio_text_image_video(sub_str[1:], fs=frontend.fs)
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
except Exception as e:
logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
@@ -604,6 +607,15 @@
frontend=frontend,
is_final=True,
) # speech: [b, T, d]
+
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = (
+ speech_lengths.sum().item()
+ * frontend.frame_shift
+ * frontend.lfr_n
+ / 1000
+ )
if kwargs.get("permute", True):
speech = speech.permute(0, 2, 1)
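Assuming frontend.frame_shift is in milliseconds and frontend.lfr_n is the low-frame-rate stacking factor (consistent with the division by 1000 above), batch_data_time is the batch's audio duration in seconds. A minimal sketch of how these fields could feed a real-time-factor metric downstream; "decode_time" is a hypothetical extra field, not populated by this patch:

def rtf(meta_data: dict) -> float:
    # elapsed wall-clock seconds recorded by the hunks above
    elapsed = (
        float(meta_data["load_data"])       # wav loading
        + float(meta_data["extract_feat"])  # fbank extraction
        + float(meta_data.get("decode_time", 0.0))  # hypothetical extra field
    )
    # seconds of compute per second of audio
    return elapsed / meta_data["batch_data_time"]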
@@ -666,12 +678,17 @@
raise NotImplementedError("batch decoding is not implemented")
contents = self.data_template(data_in[0])
- output = self.data_load_speech(contents, tokenizer, frontend, **kwargs)
+ output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
batch = to_device(output, kwargs["device"])
# audio encoder
speech = batch["speech"]
speech_lengths = batch["speech_lengths"][:, 0]
+ # optionally cast input features to half precision (the audio encoder weights must share this dtype)
+ if kwargs.get("fp16", False):
+ speech = speech.to(torch.float16)
+ elif kwargs.get("bf16", False):
+ speech = speech.to(torch.bfloat16)
encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
# audio_adaptor
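For illustration, a self-contained sketch of the cast above with a toy encoder (an assumption for demonstration; the real self.audio_encoder must already hold weights in the chosen dtype, or the matmul fails with a dtype mismatch):

import torch

encoder = torch.nn.Linear(80, 256).to(torch.bfloat16)  # toy stand-in, converted to bf16

speech = torch.randn(1, 80, 100)        # [b, d, T] dummy features
speech = speech.to(torch.bfloat16)      # the cast the patch performs
out = encoder(speech.permute(0, 2, 1))  # [b, T, d], mirroring the real call
print(out.dtype)                        # torch.bfloat16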
@@ -694,37 +711,50 @@
batch_idx, :min_len, :
]
- label = contents["assistant"][0]
- if not kwargs.get("tearchforing", False):
+ llm_dtype = kwargs.get("llm_dtype", "fp32")
+ if llm_dtype == "fp32":
+ llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
+ llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
- generated_ids = self.llm.generate(
- inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512)
- )
- generated_ids = [
- output_ids[len(input_id) :]
- for input_id, output_ids in zip(input_ids, generated_ids)
- ]
- response = tokenizer.batch_decode(
- generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
- )[0]
+ dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
+ with torch.cuda.amp.autocast(
+ enabled=llm_dtype != "fp32", dtype=dtype_map[llm_dtype]
+ ):
+ label = contents["assistant"][0]
+ self.llm = self.llm.to(dtype_map[llm_dtype])
+ inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
- loss = None
- else:
+ if not kwargs.get("tearchforing", False):
- labels_ids = batch["labels_ids"]
- labels_ids[labels_ids == -1] = -100
- attention_mask = batch.get("attention_mask", None)
- model_outputs = self.llm(
- inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
- )
+ generated_ids = self.llm.generate(
+ inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512)
+ )
+ # generated_ids = [
+ # output_ids[len(input_id) :]
+ # for input_id, output_ids in zip(input_ids, generated_ids)
+ # ]
+ response = tokenizer.batch_decode(
+ generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
+ )[0]
- preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
- response = tokenizer.batch_decode(
- preds,
- add_special_tokens=False,
- skip_special_tokens=kwargs.get("skip_special_tokens", True),
- )[0]
- loss = model_outputs.loss.item()
+ loss = None
+ else:
+ labels_ids = batch["labels_ids"]
+ labels_ids[labels_ids == -1] = -100
+ attention_mask = batch.get("attention_mask", None)
+ # attention_mask = attention_mask.to(dtype_map[llm_dtype])
+ model_outputs = self.llm(
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
+ )
+
+ preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
+ response = tokenizer.batch_decode(
+ preds,
+ add_special_tokens=False,
+ skip_special_tokens=kwargs.get("skip_special_tokens", True),
+ )[0]
+ loss = model_outputs.loss.item()
ibest_writer = None
if kwargs.get("output_dir") is not None:
@@ -733,7 +763,8 @@
ibest_writer = self.writer[f"{0 + 1}best_recog"]
results = []
- result_i = {"key": key[0], "text": response, "label": label}
+ # keep word characters, whitespace, the ideographic space and CJK ideographs; strip the rest
+ response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
+ result_i = {"key": key[0], "text": response, "text_tn": response_clean, "label": label}
if loss is not None:
result_i["loss"] = loss
results.append(result_i)
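A quick illustration of that normalization with a made-up response string (everything outside \w, \s, U+3000 and the U+4E00-U+9FFF CJK block is dropped):

import re

response = "你好, world! 这是一个测试。"
response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
print(response_clean)  # -> 你好 world 这是一个测试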
@@ -741,5 +772,6 @@
if ibest_writer is not None:
ibest_writer["text"][key[0]] = response
ibest_writer["label"][key[0]] = label
+ ibest_writer["text_tn"][key[0]] = response_clean
return results, meta_data
--
Gitblit v1.9.1