From c553a8db1712c2a5deeef5bbb68bd1fdf8d61ab7 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 13 六月 2024 17:38:01 +0800
Subject: [PATCH] decoding
---
funasr/models/llm_asr/model.py | 146 ++++++++++++++++++++++++++++++++----------------
1 files changed, 96 insertions(+), 50 deletions(-)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 21072b0..6e7939b 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -21,6 +21,8 @@
from funasr.train_utils.device_funcs import to_device
import traceback
+dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
+
@tables.register("model_classes", "LLMASR")
class LLMASR(nn.Module):
@@ -407,38 +409,60 @@
audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
audio_encoder_output_size = audio_encoder.output_size()
freeze = audio_encoder_conf.get("freeze", True)
+ freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1))
+ # if freeze_layer_num > 0:
+ # freeze_layer_num = range(freeze_layer_num)
+
if freeze:
for name, param in audio_encoder.named_parameters():
- param.requires_grad = False
+ if freeze_layer_num > 0:
+ idx = re.search(r"\.\d+\.", name)
+ if idx is not None:
+ beg, end = idx.regs[0]
+ layer_id = int(name[beg + 1 : end - 1])
+ if layer_id < freeze_layer_num:
+ param.requires_grad = False
+ elif not name.startswith("audio_encoder.ln_post"):
+ param.requires_grad = False
+ else:
+ param.requires_grad = False
+
audio_encoder.eval()
self.audio_encoder = audio_encoder
# llm
- hub = llm_conf.get("hub", "hf")
self.llm = None
- if hub == "hf":
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
- init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
- model = AutoModelForCausalLM.from_pretrained(
- init_param_path,
- load_in_8bit=None,
- device_map=None,
- use_cache=None,
- )
- freeze = llm_conf.get("freeze", True)
- if freeze:
- for name, param in model.named_parameters():
- param.requires_grad = False
- model.eval()
- self.llm = model
+ init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+
+ model = AutoModelForCausalLM.from_pretrained(
+ init_param_path,
+ load_in_8bit=None,
+ device_map=None,
+ use_cache=None,
+ )
+ freeze = llm_conf.get("freeze", True)
+ if freeze:
+ for name, param in model.named_parameters():
+ param.requires_grad = False
+ model.eval()
+ self.llm_dtype = llm_conf.get("llm_dtype", "fp32")
+ self.llm = model.to(dtype_map[self.llm_dtype])
+ llm_dim = model.get_input_embeddings().weight.shape[-1]
# adaptor
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
+ audio_adaptor_conf["llm_dim"] = llm_dim
audio_adaptor = adaptor_class(**audio_adaptor_conf)
+ init_param_path = audio_adaptor_conf.get("init_param_path", None)
+ if init_param_path is not None:
+ src_state = torch.load(init_param_path, map_location="cpu")
+ flag = audio_adaptor.load_state_dict(src_state, strict=False)
+ logging.info(f"Loading audio_adaptor ckpt: {init_param_path}, status: {flag}")
self.audio_adaptor = audio_adaptor
@@ -506,12 +530,17 @@
batch_idx, :min_len, :
]
- labels_ids[labels_ids == -1] = -100
-
- model_outputs = self.llm(
- inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
- )
- loss = model_outputs.loss
+ with torch.cuda.amp.autocast(
+ enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
+ ):
+ labels_ids[labels_ids == -1] = -100
+ attention_mask[attention_mask < 0] = 0
+ model_outputs = self.llm(
+ inputs_embeds=inputs_embeds.to(dtype_map[self.llm_dtype]),
+ attention_mask=attention_mask,
+ labels=labels_ids,
+ )
+ loss = model_outputs.loss
stats = {}
with torch.no_grad():
@@ -684,6 +713,11 @@
# audio encoder
speech = batch["speech"]
speech_lengths = batch["speech_lengths"][:, 0]
+ # fp16
+ if kwargs.get("fp16", False):
+ speech = speech.to(torch.float16)
+ elif kwargs.get("bf16", False):
+ speech = speech.to(torch.bfloat16)
encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
# audio_adaptor
@@ -706,37 +740,49 @@
batch_idx, :min_len, :
]
- label = contents["assistant"][0]
- if not kwargs.get("tearchforing", False):
+ llm_dtype = kwargs.get("llm_dtype", "fp32")
+ if llm_dtype == "fp32":
+ llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
+ llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
- generated_ids = self.llm.generate(
- inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512)
- )
- # generated_ids = [
- # output_ids[len(input_id) :]
- # for input_id, output_ids in zip(input_ids, generated_ids)
- # ]
- response = tokenizer.batch_decode(
- generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
- )[0]
+ with torch.cuda.amp.autocast(
+ enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
+ ):
+ label = contents["assistant"][0]
+ self.llm = self.llm.to(dtype_map[llm_dtype])
+ inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
- loss = None
- else:
+ if not kwargs.get("tearchforing", False):
- labels_ids = batch["labels_ids"]
- labels_ids[labels_ids == -1] = -100
- attention_mask = batch.get("attention_mask", None)
- model_outputs = self.llm(
- inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
- )
+ generated_ids = self.llm.generate(
+ inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512)
+ )
+ # generated_ids = [
+ # output_ids[len(input_id) :]
+ # for input_id, output_ids in zip(input_ids, generated_ids)
+ # ]
+ response = tokenizer.batch_decode(
+ generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
+ )[0]
- preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
- response = tokenizer.batch_decode(
- preds,
- add_special_tokens=False,
- skip_special_tokens=kwargs.get("skip_special_tokens", True),
- )[0]
- loss = model_outputs.loss.item()
+ loss = None
+ else:
+
+ labels_ids = batch["labels_ids"]
+ labels_ids[labels_ids == -1] = -100
+ attention_mask = batch.get("attention_mask", None)
+ # attention_mask = attention_mask.to(dtype_map[llm_dtype])
+ model_outputs = self.llm(
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
+ )
+
+ preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
+ response = tokenizer.batch_decode(
+ preds,
+ add_special_tokens=False,
+ skip_special_tokens=kwargs.get("skip_special_tokens", True),
+ )[0]
+ loss = model_outputs.loss.item()
ibest_writer = None
if kwargs.get("output_dir") is not None:
--
Gitblit v1.9.1