From cce5d9999dabaf257347fbadb7ccc2473c9a757a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 29 二月 2024 17:19:32 +0800
Subject: [PATCH] add
---
funasr/models/llm_asr/model.py | 341 +++++++++++++++++++++++++++++++++++++
funasr/datasets/llm_datasets/datasets.py | 126 ++++++++++++++
funasr/models/llm_asr/__init__.py | 0
funasr/models/llm_asr/adaptor.py | 62 ++++++
4 files changed, 529 insertions(+), 0 deletions(-)
diff --git a/funasr/datasets/llm_datasets/datasets.py b/funasr/datasets/llm_datasets/datasets.py
index 9673d76..22151a1 100644
--- a/funasr/datasets/llm_datasets/datasets.py
+++ b/funasr/datasets/llm_datasets/datasets.py
@@ -129,3 +129,129 @@
outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
return outputs
+
+
+@tables.register("dataset_classes", "AudioLLMARDataset")
+class AudioLLMARDataset(torch.utils.data.Dataset):
+ """
+ AudioLLMDataset
+ """
+
+ def __init__(self,
+ path,
+ index_ds: str = None,
+ frontend=None,
+ tokenizer=None,
+ int_pad_value: int = -1,
+ float_pad_value: float = 0.0,
+ **kwargs):
+ super().__init__()
+ index_ds_class = tables.index_ds_classes.get(index_ds)
+ self.index_ds = index_ds_class(path, **kwargs)
+ preprocessor_speech = kwargs.get("preprocessor_speech", None)
+ if preprocessor_speech:
+ preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+ preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
+ self.preprocessor_speech = preprocessor_speech
+ preprocessor_text = kwargs.get("preprocessor_text", None)
+ if preprocessor_text:
+ preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+ preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
+ self.preprocessor_text = preprocessor_text
+
+ self.frontend = frontend
+ self.fs = 16000 if frontend is None else frontend.fs
+ self.data_type = "sound"
+ self.tokenizer = tokenizer
+
+ self.float_pad_value = float_pad_value
+ self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
+ self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
+ self.prompt) # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
+ self.prompt_af = ""
+ self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
+ self.int_pad_value = self.IGNORE_INDEX
+
+ def get_source_len(self, index):
+ item = self.index_ds[index]
+ return self.index_ds.get_source_len(item)
+
+ def get_target_len(self, index):
+ item = self.index_ds[index]
+ return self.index_ds.get_target_len(item)
+
+ def __len__(self):
+ return len(self.index_ds)
+
+ def __getitem__(self, index):
+ item = self.index_ds[index]
+ # import pdb;
+ # pdb.set_trace()
+ source = item["source"]
+ data_src = load_audio_text_image_video(source, fs=self.fs)
+ if self.preprocessor_speech:
+ data_src = self.preprocessor_speech(data_src, fs=self.fs)
+ speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend,
+ is_final=True) # speech: [b, T, d]
+ speech = speech.squeeze(0)
+
+ target = item["target"]
+ if self.preprocessor_text:
+ target = self.preprocessor_text(target)
+
+ prompt_ids_pre = self.tokenizer.encode(self.prompt_pre) # [bos,prompt]
+ prompt_pre_length = len(prompt_ids_pre)
+
+ prompt_input = "{}{}".format(self.prompt_pre, target)
+ prompt_input_ids = self.tokenizer.encode(prompt_input)
+ audio_length = len(prompt_input_ids) - prompt_pre_length
+ input_ids = prompt_input_ids + [self.tokenizer.pad_token_id]
+ input_ids = torch.tensor(input_ids, dtype=torch.int64) # [bos, prompt, input, pad]
+ input_ids[prompt_pre_length:] = -1 # [bos, prompt,-1,-1]
+ attention_mask = input_ids.ge(-1) # [true, true, true, true], length mask
+
+ prompt_answer = "{}{}".format(self.prompt_pre, target)
+ prompt_answer_ids = self.tokenizer.encode(prompt_answer)
+ answer_length = len(prompt_answer_ids) - prompt_pre_length
+ labels_ids = copy.deepcopy(prompt_input_ids) + [self.tokenizer.eos_token_id]
+ labels_ids = torch.tensor(labels_ids, dtype=torch.int64) # [bos, prompt, input, eos]
+ labels_ids[:prompt_pre_length] = -1 # [-1, -1, input, eos]
+ label_mask = labels_ids.ge(0) # [False,False,True,True]
+ labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,input,eos]
+
+ audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
+ audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
+
+ ids = self.tokenizer.encode(target) # token ids is different from labels_ids
+ text = torch.tensor(ids, dtype=torch.int64)
+ text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
+
+ return {"speech": speech,
+ "speech_lengths": speech_lengths,
+ "text": text,
+ "text_lengths": text_lengths,
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "labels_ids": labels_ids,
+ "label_mask": label_mask,
+ "audio_mask": audio_mask,
+ }
+
+ def collator(self, samples: list = None):
+ outputs = {}
+ for sample in samples:
+ for key in sample.keys():
+ if key not in outputs:
+ outputs[key] = []
+ outputs[key].append(sample[key])
+
+ for key, data_list in outputs.items():
+ if isinstance(data_list[0], torch.Tensor):
+ if data_list[0].dtype == torch.int64:
+
+ pad_value = self.int_pad_value
+ else:
+ pad_value = self.float_pad_value
+
+ outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+ return outputs
diff --git a/funasr/models/llm_asr/__init__.py b/funasr/models/llm_asr/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/llm_asr/__init__.py
diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py
new file mode 100644
index 0000000..2093588
--- /dev/null
+++ b/funasr/models/llm_asr/adaptor.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+
+from funasr.register import tables
+
+@tables.register("adaptor_classes", "Linear")
+class Linear(nn.Module):
+ def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
+ super().__init__()
+ self.k = downsample_rate
+ self.encoder_dim = encoder_dim
+ self.llm_dim = llm_dim
+ self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
+ self.relu = nn.ReLU()
+ self.linear2 = nn.Linear(ffn_dim, self.llm_dim)
+
+ def forward(self, x):
+ batch_size, seq_len, dim = x.size()
+ num_frames_to_discard = seq_len % self.k
+ if num_frames_to_discard > 0:
+ x = x[:, :-num_frames_to_discard, :]
+ seq_len = x.size(1)
+
+ x = x.contiguous()
+ x = x.view(batch_size, seq_len // self.k, dim * self.k)
+ x = self.linear1(x)
+ x = self.relu(x)
+ x = self.linear2(x)
+ return x
+
+@tables.register("adaptor_classes", "QFormer")
+class EncoderProjectorQFormer(nn.Module):
+ def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
+ super().__init__()
+ self.encoder_dim = encoder_dim
+ self.llm_dim = llm_dim
+ from transformers import Blip2QFormerConfig, Blip2QFormerModel
+ configuration = Blip2QFormerConfig()
+ configuration.encoder_hidden_size = self.encoder_dim
+ configuration.num_hidden_layers = 2
+
+ self.query_len = 64
+ self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size))
+ self.query.data.normal_(mean=0.0, std=1.0)
+ self.qformer = Blip2QFormerModel(configuration)
+
+ self.linear = nn.Linear(configuration.hidden_size, self.llm_dim)
+ self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5)
+
+ def forward(self, x, atts):
+ query = self.query.expand(x.shape[0], -1, -1)
+
+ query_output = self.qformer(
+ query_embeds=query,
+ encoder_hidden_states=x,
+ encoder_attention_mask=atts,
+ return_dict=True,
+ )
+
+ query_proj = self.norm(self.linear(query_output.last_hidden_state))
+
+ return query_proj
\ No newline at end of file
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
new file mode 100644
index 0000000..4139d8c
--- /dev/null
+++ b/funasr/models/llm_asr/model.py
@@ -0,0 +1,341 @@
+import logging
+from typing import Union, Dict, List, Tuple, Optional
+
+import time
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import autocast
+
+from funasr.models.scama.utils import sequence_mask
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.ctc.ctc import CTC
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy, compute_accuracy
+# from funasr.models.e2e_asr_common import ErrorCalculator
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+
+
+@tables.register("model_classes", "LLMASR")
+class LLMASR(nn.Module):
+ """ """
+
+ def __init__(
+ self,
+ specaug: str = None,
+ specaug_conf: dict = None,
+ normalize: str = None,
+ normalize_conf: dict = None,
+ encoder: str = None,
+ encoder_conf: dict = None,
+ decoder: str = None,
+ decoder_conf: dict = None,
+ ctc: str = None,
+ ctc_conf: dict = None,
+ ctc_weight: float = 0.5,
+ llm: str = None,
+ llm_conf: dict = None,
+ adaptor: str = None,
+ adaptor_conf: dict = None,
+ input_size: int = 80,
+ vocab_size: int = -1,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ # extract_feats_in_collect_stats: bool = True,
+ share_embedding: bool = False,
+ # preencoder: Optional[AbsPreEncoder] = None,
+ # postencoder: Optional[AbsPostEncoder] = None,
+ **kwargs,
+ ):
+
+ super().__init__()
+
+ if specaug is not None:
+ specaug_class = tables.specaug_classes.get(specaug)
+ specaug = specaug_class(**specaug_conf)
+ if normalize is not None:
+ normalize_class = tables.normalize_classes.get(normalize)
+ normalize = normalize_class(**normalize_conf)
+
+ # audio encoder
+ hub = encoder_conf.get("hub", None)
+ if hub == "funasr":
+ from funasr import AutoModel
+ init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+ model = AutoModel(model=init_param_path, model_revision="v2.0.4")
+ # frontend = model.kwargs.get("frontend")
+ model.model.decoder = None
+
+ self.audio_encoder = model.model
+ # self.frontend = frontend
+
+ elif hub == "hf":
+ pass
+ else:
+ encoder_class = tables.encoder_classes.get(encoder)
+ encoder = encoder_class(input_size=input_size, **encoder_conf)
+ encoder_output_size = encoder.output_size()
+
+ # llm
+ hub = llm_conf.get("hub", "hf")
+ self.llm = None
+ if hub == "hf":
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+
+ init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+ model = AutoModelForCausalLM.from_pretrained(
+ init_param_path,
+ load_in_8bit=None,
+ device_map=None,
+ use_cache=None,
+ )
+ freeze = llm_conf.get("freeze", True)
+ if freeze:
+ for name, param in model.named_parameters():
+ param.requires_grad = False
+ model.eval()
+ self.llm = model
+
+ # adaptor
+ adaptor_class = tables.adaptor_classes.get(adaptor)
+ adaptor = adaptor_class(**adaptor_conf)
+
+ self.adaptor = adaptor
+
+
+ self.blank_id = blank_id
+ self.sos = sos if sos is not None else vocab_size - 1
+ self.eos = eos if eos is not None else vocab_size - 1
+ self.vocab_size = vocab_size
+ self.ignore_id = ignore_id
+ self.specaug = specaug
+ self.normalize = normalize
+ self.encoder = encoder
+
+
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+ #
+ # if report_cer or report_wer:
+ # self.error_calculator = ErrorCalculator(
+ # token_list, sym_space, sym_blank, report_cer, report_wer
+ # )
+ #
+ self.error_calculator = None
+
+ self.length_normalized_loss = length_normalized_loss
+ self.beam_search = None
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ input_ids: torch.Tensor,
+ attention_mask:torch.Tensor,
+ labels_ids: torch.Tensor,
+ label_mask: torch.Tensor,
+ audio_mask: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ # import pdb;
+ # pdb.set_trace()
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+
+ # audio encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)
+
+ # adaptor
+ encoder_out = self.adaptor(encoder_out)
+
+ if input_ids is not None:
+ input_ids[input_ids == -1] = 0
+ input_ids[input_ids == -100] = 0
+ if hasattr(self.llm.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.embed_tokens(input_ids)
+ elif hasattr(self.llm.model.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
+ else:
+ inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
+
+ if audio_mask is not None:
+ batch_size, token_num, dims = inputs_embeds.shape
+ _, l, _ = encoder_out.shape
+ encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
+ inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
+ inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
+
+ model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
+ loss = model_outputs.loss
+
+
+ stats = {}
+ with torch.no_grad():
+ preds = torch.argmax(model_outputs.logits, -1)
+ acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+ stats["acc"] = acc_att
+
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ if self.length_normalized_loss:
+ batch_size = int((text_lengths + 1).sum())
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ audio_mask = kwargs.get("audio_mask", None)
+ audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
+
+ batch = {"speech": speech, "speech_lengths": speech_lengths}
+ enc, enc_lens = self.audio_encoder.encode(**batch)
+ with autocast(False):
+ enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
+ pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
+ mask=enc_mask,
+ target_label_length=audio_token_lengths,
+ )
+
+ return pre_acoustic_embeds, pre_token_length
+
+
+ def inference(self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+
+ prompt = kwargs.get("prompt", "Transcribe speech to text.")
+
+ if kwargs.get("batch_size", 1) > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+
+
+
+ meta_data = {}
+ if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
+ speech, speech_lengths = data_in, data_lengths
+ if len(speech.shape) < 3:
+ speech = speech[None, :, :]
+ if speech_lengths is None:
+ speech_lengths = speech.shape[1]
+ else:
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer)
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ # adaptor
+ encoder_out = self.adaptor(encoder_out)
+
+
+ prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
+ prompt_ids = tokenizer.encode(prompt_pre)
+ prompt_length = len(prompt_ids)
+ prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
+
+
+ if hasattr(self.llm.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+ elif hasattr(self.llm.model.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+ else:
+ inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+
+ inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio]
+ attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
+
+ # model_outputs = self.llm.generate(
+ # inputs_embeds=inputs_embeds,
+ # max_length=kwargs.get("max_length", 200),
+ # max_new_tokens=kwargs.get("max_new_tokens", 200),
+ # num_beams=kwargs.get("num_beams", 4),
+ # do_sample=kwargs.get("do_sample", False),
+ # min_length=kwargs.get("min_length", 1),
+ # top_p=kwargs.get("top_p", 1.0),
+ # repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+ # length_penalty=kwargs.get("length_penalty", 1.0),
+ # temperature=kwargs.get("temperature", 1.0),
+ # attention_mask=attention_mask,
+ # bos_token_id=tokenizer.bos_token_id,
+ # eos_token_id=tokenizer.eos_token_id,
+ # pad_token_id=tokenizer.pad_token_id
+ # )
+
+
+ model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
+ preds = torch.argmax(model_outputs.logits, -1)
+ text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
+
+ text = text[0].split(': ')[-1]
+ text = text.strip()
+
+ # preds = torch.argmax(model_outputs.logits, -1)
+
+ ibest_writer = None
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = self.writer[f"{0 + 1}best_recog"]
+
+ results = []
+ result_i = {"key": key[0], "text": text}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["text"][key[0]] = text
+
+
+
+
+ return results, meta_data
+
--
Gitblit v1.9.1