From 5de8bfdcd8a617ac13c13478505401bbf4e57472 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 13 六月 2024 15:38:17 +0800
Subject: [PATCH] decoding
---
funasr/models/llm_asr/model.py | 360 +++++++++++++++++++++++++++++++++++++++++++----------------
1 files changed, 259 insertions(+), 101 deletions(-)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 82ad134..15969e3 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -6,7 +6,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
-
+import re
from funasr.models.scama.utils import sequence_mask
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.ctc.ctc import CTC
@@ -18,6 +18,10 @@
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+import traceback
+
+dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
@tables.register("model_classes", "LLMASR")
@@ -405,38 +409,60 @@
audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
audio_encoder_output_size = audio_encoder.output_size()
freeze = audio_encoder_conf.get("freeze", True)
+ freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1))
+ if freeze_layer_num > 0:
+ freeze_layer_num = range(freeze_layer_num)
+
if freeze:
for name, param in audio_encoder.named_parameters():
- param.requires_grad = False
+ if isinstance(freeze_layer_num, (list, tuple)):
+ idx = re.search(r"\.\d+\.", name)
+ if idx is not None:
+ beg, end = idx.regs[0]
+ layer_id = int(name[beg + 1 : end - 1])
+ if layer_id in freeze_layer_num:
+ param.requires_grad = False
+ else:
+ param.requires_grad = False
+ else:
+ param.requires_grad = False
+
audio_encoder.eval()
self.audio_encoder = audio_encoder
# llm
- hub = llm_conf.get("hub", "hf")
self.llm = None
- if hub == "hf":
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
- init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
- model = AutoModelForCausalLM.from_pretrained(
- init_param_path,
- load_in_8bit=None,
- device_map=None,
- use_cache=None,
- )
- freeze = llm_conf.get("freeze", True)
- if freeze:
- for name, param in model.named_parameters():
- param.requires_grad = False
- model.eval()
- self.llm = model
+ init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+
+ model = AutoModelForCausalLM.from_pretrained(
+ init_param_path,
+ load_in_8bit=None,
+ device_map=None,
+ use_cache=None,
+ )
+ freeze = llm_conf.get("freeze", True)
+ if freeze:
+ for name, param in model.named_parameters():
+ param.requires_grad = False
+ model.eval()
+ self.llm = model
+ llm_dim = model.get_input_embeddings().weight.shape[-1]
+ self.llm_dtype = llm_conf.get("llm_dtype", "fp32")
# adaptor
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
+ audio_adaptor_conf["llm_dim"] = llm_dim
audio_adaptor = adaptor_class(**audio_adaptor_conf)
+ init_param_path = audio_adaptor_conf.get("init_param_path", None)
+ if init_param_path is not None:
+ src_state = torch.load(init_param_path, map_location="cpu")
+ flag = audio_adaptor.load_state_dict(src_state, strict=False)
+ logging.info(f"Loading audio_adaptor ckpt: {init_param_path}, status: {flag}")
self.audio_adaptor = audio_adaptor
@@ -488,8 +514,7 @@
fbank_fake_len = fbank_fake_lens[batch_idx].item()
fbank_beg_idx = fbank_beg[batch_idx, 0].item()
min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx)
- fbank_fake_len = encoder_out_lens[batch_idx].item()
- min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx)
+
try:
inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
batch_idx, :min_len, :
@@ -497,19 +522,23 @@
except Exception as e:
logging.error(f"{str(e)}, {traceback.format_exc()}")
logging.info(
- f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, min_len: {min_len}, fbank_fake_len: {fbank_fake_len}"
+ f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, min_len: {min_len}, fbank_fake_len: {fbank_fake_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens[batch_idx].item()}"
)
fbank_fake_len = encoder_out_lens[batch_idx].item()
- min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx)
+ min_len = min(fbank_fake_len, min_len)
inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
batch_idx, :min_len, :
]
- labels_ids[labels_ids == -1] = -100
- model_outputs = self.llm(
- inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
- )
- loss = model_outputs.loss
+ with torch.cuda.amp.autocast(
+ enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
+ ):
+ labels_ids[labels_ids == -1] = -100
+ attention_mask[attention_mask < 0] = 0
+ model_outputs = self.llm(
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
+ )
+ loss = model_outputs.loss
stats = {}
with torch.no_grad():
@@ -532,6 +561,133 @@
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
+ def data_template(self, data):
+ system, user, assistant = [], [], []
+ for i, item in enumerate(data):
+ role = item["role"]
+ content = item["content"]
+ if role == "system":
+ system.append(content)
+ elif role == "user":
+ user.append(content)
+ elif role == "assistant":
+ assistant.append(content)
+
+ system = system * len(user)
+
+ contents = {
+ "system": system,
+ "user": user,
+ "assistant": assistant,
+ }
+
+ return contents
+
+ def data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
+
+ system = contents["system"]
+ user = contents["user"]
+ assistant = contents["assistant"]
+ pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
+ input_ids, labels, source_ids, target_ids, fbank, fbank_lens, fbank_mask, fbank_beg = (
+ [],
+ [],
+ [],
+ [],
+ [],
+ [],
+ [],
+ [],
+ )
+
+ for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
+
+ source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+
+ splits = pattern.split(source_input)
+ source_ids_i = []
+ fbank_mask_i = []
+ fbank_beg_i = []
+ fbank_lens_i = []
+ # target_ids_i = []
+ for k, sub_str in enumerate(splits):
+ if not sub_str.startswith("<|startofspeech|>"):
+ sub_token = tokenizer.encode(sub_str)
+ source_ids_i += sub_token
+ fbank_mask_i += [0] * len(sub_token)
+ else:
+ sub_str = sub_str.replace("<|startofspeech|>", "").replace(
+ "<|endofspeech|>", ""
+ )
+ if sub_str.startswith("!"):
+ try:
+ time1 = time.perf_counter()
+ data_src = load_audio_text_image_video(sub_str[1:], fs=frontend.fs)
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ except Exception as e:
+ logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
+
+ speech, speech_lengths = extract_fbank(
+ data_src,
+ data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend,
+ is_final=True,
+ ) # speech: [b, T, d]
+
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = (
+ speech_lengths.sum().item()
+ * frontend.frame_shift
+ * frontend.lfr_n
+ / 1000
+ )
+
+ if kwargs.get("permute", True):
+ speech = speech.permute(0, 2, 1)
+
+ olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
+ olens = 1 + (olens - 3 + 2 * 1) // 2
+ sub_token_len = (olens - 1) // 2 + 1
+ sub_token = [0] * sub_token_len
+ fbank_beg_i = [len(source_ids_i)]
+ source_ids_i += sub_token
+ fbank_mask_i += [1] * len(sub_token)
+
+ source_mask = [-100] * len(source_ids_i)
+ target_out = f"{target_out}<|im_end|>"
+ target_ids = tokenizer.encode(target_out)
+ input_ids += source_ids_i + target_ids
+ labels += source_mask + target_ids
+ fbank_mask += fbank_mask_i
+ fbank_beg.append(fbank_beg_i)
+
+ input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
+ attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
+ labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]
+ source_ids = torch.tensor(source_ids_i, dtype=torch.int64)
+ target_ids = torch.tensor(target_ids, dtype=torch.int64)
+
+ fbank = speech[0, :, :]
+ fbank_lens = speech_lengths
+ fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
+ fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
+
+ output = {
+ "speech": fbank[None, :, :],
+ "speech_lengths": fbank_lens[:, None],
+ "fbank_mask": fbank_mask[None, :],
+ "fbank_beg": fbank_beg[None,],
+ "input_ids": input_ids[None, :],
+ "attention_mask": attention_mask[None, :],
+ "labels_ids": labels[None, :],
+ "source_ids": source_ids[None, :],
+ "target_ids": target_ids[None, :],
+ }
+
+ return output
+
def inference(
self,
data_in,
@@ -542,92 +698,89 @@
**kwargs,
):
- prompt = kwargs.get("prompt", "Transcribe speech to text.")
+ meta_data = {}
+ prompt = kwargs.get("prompt", None)
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
- meta_data = {}
- if (
- isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
- ): # fbank
- speech, speech_lengths = data_in, data_lengths
- if len(speech.shape) < 3:
- speech = speech[None, :, :]
- if speech_lengths is None:
- speech_lengths = speech.shape[1]
- else:
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(
- data_in,
- fs=frontend.fs,
- audio_fs=kwargs.get("fs", 16000),
- data_type=kwargs.get("data_type", "sound"),
- tokenizer=tokenizer,
- )
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(
- audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
- )
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- meta_data["batch_data_time"] = (
- speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
- )
+ contents = self.data_template(data_in[0])
+ output = self.data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
+ batch = to_device(output, kwargs["device"])
- speech = speech.to(device=kwargs["device"])
- speech_lengths = speech_lengths.to(device=kwargs["device"])
+ # audio encoder
+ speech = batch["speech"]
+ speech_lengths = batch["speech_lengths"][:, 0]
+ # fp16
+ if kwargs.get("fp16", False):
+ speech = speech.to(torch.float16)
+ elif kwargs.get("bf16", False):
+ speech = speech.to(torch.bfloat16)
+ encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
- # Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ # audio_adaptor
+ encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
- # adaptor
- encoder_out = self.audio_adaptor(encoder_out)
+ input_ids = batch["input_ids"]
+ source_ids = batch["source_ids"]
+ if not kwargs.get("tearchforing", False):
+ input_ids = source_ids
+ input_ids[input_ids < 0] = 0
+ inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
- prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
- prompt_ids = tokenizer.encode(prompt_pre)
- prompt_length = len(prompt_ids)
- prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
+ batch_size, token_num, dims = inputs_embeds.shape
+ fbank_beg = batch["fbank_beg"]
+ for batch_idx in range(batch_size):
- if hasattr(self.llm.model, "embed_tokens"):
- inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
- elif hasattr(self.llm.model.model, "embed_tokens"):
- inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
- else:
- inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+ min_len = encoder_out_lens[batch_idx].item()
+ fbank_beg_idx = fbank_beg[batch_idx]
+ inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
+ batch_idx, :min_len, :
+ ]
- inputs_embeds = torch.cat(
- (inputs_embeds[None, :, :], encoder_out), dim=1
- ) # [prompt, audio]
- attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(
- kwargs["device"]
- )
+ llm_dtype = kwargs.get("llm_dtype", "fp32")
+ if llm_dtype == "fp32":
+ llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
+ llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
- preds = self.llm.generate(
- inputs_embeds=inputs_embeds,
- max_length=kwargs.get("max_length", 200),
- max_new_tokens=kwargs.get("max_new_tokens", 200),
- num_beams=kwargs.get("num_beams", 4),
- do_sample=kwargs.get("do_sample", False),
- min_length=kwargs.get("min_length", 1),
- top_p=kwargs.get("top_p", 1.0),
- repetition_penalty=kwargs.get("repetition_penalty", 1.0),
- length_penalty=kwargs.get("length_penalty", 1.0),
- temperature=kwargs.get("temperature", 1.0),
- attention_mask=attention_mask,
- bos_token_id=tokenizer.bos_token_id,
- eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id,
- )
+ with torch.cuda.amp.autocast(
+ enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
+ ):
+ label = contents["assistant"][0]
+ self.llm = self.llm.to(dtype_map[llm_dtype])
+ inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
- text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
+ if not kwargs.get("tearchforing", False):
- text = text[0].split(": ")[-1]
- text = text.strip()
+ generated_ids = self.llm.generate(
+ inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512)
+ )
+ # generated_ids = [
+ # output_ids[len(input_id) :]
+ # for input_id, output_ids in zip(input_ids, generated_ids)
+ # ]
+ response = tokenizer.batch_decode(
+ generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
+ )[0]
- # preds = torch.argmax(model_outputs.logits, -1)
+ loss = None
+ else:
+
+ labels_ids = batch["labels_ids"]
+ labels_ids[labels_ids == -1] = -100
+ attention_mask = batch.get("attention_mask", None)
+ # attention_mask = attention_mask.to(dtype_map[llm_dtype])
+ model_outputs = self.llm(
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
+ )
+
+ preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
+ response = tokenizer.batch_decode(
+ preds,
+ add_special_tokens=False,
+ skip_special_tokens=kwargs.get("skip_special_tokens", True),
+ )[0]
+ loss = model_outputs.loss.item()
ibest_writer = None
if kwargs.get("output_dir") is not None:
@@ -636,10 +789,15 @@
ibest_writer = self.writer[f"{0 + 1}best_recog"]
results = []
- result_i = {"key": key[0], "text": text}
+ response_clean = re.sub("[^\w\s\u3000\u4e00-\u9fff]+", "", response)
+ result_i = {"key": key[0], "text": response, "text_tn": response_clean, "label": label}
+ if loss is not None:
+ result_i["loss"] = loss
results.append(result_i)
if ibest_writer is not None:
- ibest_writer["text"][key[0]] = text
+ ibest_writer["text"][key[0]] = response
+ ibest_writer["label"][key[0]] = label
+ ibest_writer["text_tn"][key[0]] = response_clean
return results, meta_data
--
Gitblit v1.9.1