From 6fd83d23ee64280d30fd99f0fe2ca0f52903dca1 Mon Sep 17 00:00:00 2001
From: Marlowe <54339989+ZihanLiao@users.noreply.github.com>
Date: 星期五, 14 六月 2024 10:36:28 +0800
Subject: [PATCH] fix paramter 'quantize' unused issue (#1813)
---
funasr/models/llm_asr/model.py | 68 ++++++++++++++++++++-------------
1 files changed, 41 insertions(+), 27 deletions(-)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 21072b0..519918c 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -684,6 +684,13 @@
# audio encoder
speech = batch["speech"]
speech_lengths = batch["speech_lengths"][:, 0]
+ # fp16
+ if kwargs.get("fp16", False):
+ speech = speech.to(torch.float16)
+ encoder_out_lens = encoder_out_lens.to(torch.float16)
+ elif kwargs.get("bf16", False):
+ speech = speech.to(torch.bfloat16)
+ encoder_out_lens = encoder_out_lens.to(torch.bfloat16)
encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
# audio_adaptor
@@ -706,37 +713,44 @@
batch_idx, :min_len, :
]
- label = contents["assistant"][0]
- if not kwargs.get("tearchforing", False):
+ llm_dtype = kwargs.get("llm_dtype", "fp32")
+ dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
+ with torch.cuda.amp.autocast(dtype=dtype_map[llm_dtype]):
+ label = contents["assistant"][0]
+ # self.llm = self.llm.to(dtype_map[llm_dtype])
+ # inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
- generated_ids = self.llm.generate(
- inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512)
- )
- # generated_ids = [
- # output_ids[len(input_id) :]
- # for input_id, output_ids in zip(input_ids, generated_ids)
- # ]
- response = tokenizer.batch_decode(
- generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
- )[0]
+ if not kwargs.get("tearchforing", False):
- loss = None
- else:
+ generated_ids = self.llm.generate(
+ inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512)
+ )
+ # generated_ids = [
+ # output_ids[len(input_id) :]
+ # for input_id, output_ids in zip(input_ids, generated_ids)
+ # ]
+ response = tokenizer.batch_decode(
+ generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
+ )[0]
- labels_ids = batch["labels_ids"]
- labels_ids[labels_ids == -1] = -100
- attention_mask = batch.get("attention_mask", None)
- model_outputs = self.llm(
- inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
- )
+ loss = None
+ else:
- preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
- response = tokenizer.batch_decode(
- preds,
- add_special_tokens=False,
- skip_special_tokens=kwargs.get("skip_special_tokens", True),
- )[0]
- loss = model_outputs.loss.item()
+ labels_ids = batch["labels_ids"]
+ labels_ids[labels_ids == -1] = -100
+ attention_mask = batch.get("attention_mask", None)
+ # attention_mask = attention_mask.to(dtype_map[llm_dtype])
+ model_outputs = self.llm(
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
+ )
+
+ preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
+ response = tokenizer.batch_decode(
+ preds,
+ add_special_tokens=False,
+ skip_special_tokens=kwargs.get("skip_special_tokens", True),
+ )[0]
+ loss = model_outputs.loss.item()
ibest_writer = None
if kwargs.get("output_dir") is not None:
--
Gitblit v1.9.1