From cce5d9999dabaf257347fbadb7ccc2473c9a757a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 29 Feb 2024 17:19:32 +0800
Subject: [PATCH] add LLMASR model, adaptor, and AudioLLMARDataset for LLM-based ASR

---
 funasr/models/llm_asr/model.py           |  341 +++++++++++++++++++++++++++++++++++++
 funasr/datasets/llm_datasets/datasets.py |  126 ++++++++++++++
 funasr/models/llm_asr/__init__.py        |    0 
 funasr/models/llm_asr/adaptor.py         |   62 ++++++
 4 files changed, 529 insertions(+), 0 deletions(-)

diff --git a/funasr/datasets/llm_datasets/datasets.py b/funasr/datasets/llm_datasets/datasets.py
index 9673d76..22151a1 100644
--- a/funasr/datasets/llm_datasets/datasets.py
+++ b/funasr/datasets/llm_datasets/datasets.py
@@ -129,3 +129,129 @@
                 
                 outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
         return outputs
+
+
+@tables.register("dataset_classes", "AudioLLMARDataset")
+class AudioLLMARDataset(torch.utils.data.Dataset):
+    """
+    AudioLLMDataset
+    """
+    
+    def __init__(self,
+                 path,
+                 index_ds: str = None,
+                 frontend=None,
+                 tokenizer=None,
+                 int_pad_value: int = -1,
+                 float_pad_value: float = 0.0,
+                 **kwargs):
+        super().__init__()
+        index_ds_class = tables.index_ds_classes.get(index_ds)
+        self.index_ds = index_ds_class(path, **kwargs)
+        preprocessor_speech = kwargs.get("preprocessor_speech", None)
+        if preprocessor_speech:
+            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
+        self.preprocessor_speech = preprocessor_speech
+        preprocessor_text = kwargs.get("preprocessor_text", None)
+        if preprocessor_text:
+            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
+        self.preprocessor_text = preprocessor_text
+        
+        self.frontend = frontend
+        self.fs = 16000 if frontend is None else frontend.fs
+        self.data_type = "sound"
+        self.tokenizer = tokenizer
+        
+        self.float_pad_value = float_pad_value
+        self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
+        self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
+            self.prompt)  # full template: "USER: \nINSTRUCTION: {}\nINPUT: {}\nASSISTANT: "
+        self.prompt_af = ""
+        self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
+        self.int_pad_value = self.IGNORE_INDEX
+    
+    def get_source_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_source_len(item)
+    
+    def get_target_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_target_len(item)
+    
+    def __len__(self):
+        return len(self.index_ds)
+    
+    def __getitem__(self, index):
+        item = self.index_ds[index]
+        source = item["source"]
+        data_src = load_audio_text_image_video(source, fs=self.fs)
+        if self.preprocessor_speech:
+            data_src = self.preprocessor_speech(data_src, fs=self.fs)
+        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend,
+                                               is_final=True)  # speech: [b, T, d]
+        speech = speech.squeeze(0)
+        
+        target = item["target"]
+        if self.preprocessor_text:
+            target = self.preprocessor_text(target)
+        
+        prompt_ids_pre = self.tokenizer.encode(self.prompt_pre)  # [bos,prompt]
+        prompt_pre_length = len(prompt_ids_pre)
+        
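+        # Target-token positions in input_ids act as placeholders (set to -1 below); forward() replaces them with audio embeddings selected by audio_mask.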
+        prompt_input = "{}{}".format(self.prompt_pre, target)
+        prompt_input_ids = self.tokenizer.encode(prompt_input)
+        audio_length = len(prompt_input_ids) - prompt_pre_length
+        input_ids = prompt_input_ids + [self.tokenizer.pad_token_id]
+        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [bos, prompt, input, pad]
+        input_ids[prompt_pre_length:] = -1  # [bos, prompt,-1,-1]
+        attention_mask = input_ids.ge(-1)  # [true, true, true, true], length mask
+        
+        prompt_answer = "{}{}".format(self.prompt_pre, target)
+        prompt_answer_ids = self.tokenizer.encode(prompt_answer)
+        answer_length = len(prompt_answer_ids) - prompt_pre_length
+        labels_ids = copy.deepcopy(prompt_input_ids) + [self.tokenizer.eos_token_id]
+        labels_ids = torch.tensor(labels_ids, dtype=torch.int64)  # [bos, prompt, input, eos]
+        labels_ids[:prompt_pre_length] = -1  # [-1, -1, input, eos]
+        label_mask = labels_ids.ge(0)  # [False,False,True,True]
+        labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100,-100,input,eos]
+        
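+        # audio_mask is 1 over the placeholder span and 0 on the prompt and the trailing pad slot.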
+        audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
+        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
+        
+        ids = self.tokenizer.encode(target)  # note: these token ids differ from labels_ids
+        text = torch.tensor(ids, dtype=torch.int64)
+        text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
+        
+        return {"speech": speech,
+                "speech_lengths": speech_lengths,
+                "text": text,
+                "text_lengths": text_lengths,
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "labels_ids": labels_ids,
+                "label_mask": label_mask,
+                "audio_mask": audio_mask,
+                }
+    
+    def collator(self, samples: list = None):
+        outputs = {}
+        for sample in samples:
+            for key in sample.keys():
+                if key not in outputs:
+                    outputs[key] = []
+                outputs[key].append(sample[key])
+        
+        for key, data_list in outputs.items():
+            if isinstance(data_list[0], torch.Tensor):
+                if data_list[0].dtype == torch.int64:
+                    pad_value = self.int_pad_value
+                else:
+                    pad_value = self.float_pad_value
+                
+                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+        return outputs
diff --git a/funasr/models/llm_asr/__init__.py b/funasr/models/llm_asr/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/llm_asr/__init__.py
diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py
new file mode 100644
index 0000000..2093588
--- /dev/null
+++ b/funasr/models/llm_asr/adaptor.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+
+from funasr.register import tables
+
+@tables.register("adaptor_classes", "Linear")
+class Linear(nn.Module):
+    def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
+        super().__init__()
+        self.k = downsample_rate
+        self.encoder_dim = encoder_dim
+        self.llm_dim = llm_dim
+        self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
+        self.relu = nn.ReLU()
+        self.linear2 = nn.Linear(ffn_dim, self.llm_dim)
+
+    def forward(self, x):
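+        # x: [batch, seq_len, encoder_dim]; stack every k consecutive frames along the feature axis (dropping the remainder) and project to the LLM embedding size via a small FFN.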
+        batch_size, seq_len, dim = x.size()
+        num_frames_to_discard = seq_len % self.k
+        if num_frames_to_discard > 0:
+            x = x[:, :-num_frames_to_discard, :]
+        seq_len = x.size(1)
+        
+        x = x.contiguous()
+        x = x.view(batch_size, seq_len // self.k, dim * self.k)
+        x = self.linear1(x)
+        x = self.relu(x)
+        x = self.linear2(x)
+        return x
+
+@tables.register("adaptor_classes", "QFormer")
+class EncoderProjectorQFormer(nn.Module):
+    def __init__(self, downsample_rate, encoder_dim, llm_dim, ffn_dim: int = 2048, **kwargs):
+        super().__init__()
+        self.encoder_dim = encoder_dim
+        self.llm_dim = llm_dim
+        from transformers import Blip2QFormerConfig, Blip2QFormerModel
+        configuration = Blip2QFormerConfig()
+        configuration.encoder_hidden_size = self.encoder_dim
+        configuration.num_hidden_layers = 2
+        
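+        # A fixed set of 64 learnable query vectors cross-attends to the encoder outputs through a 2-layer Q-Former, compressing the audio into query_len embeddings that are then projected to llm_dim.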
+        self.query_len = 64
+        self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size))
+        self.query.data.normal_(mean=0.0, std=1.0)
+        self.qformer = Blip2QFormerModel(configuration)
+        
+        self.linear = nn.Linear(configuration.hidden_size, self.llm_dim)
+        self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5)
+    
+    def forward(self, x, atts):
+        query = self.query.expand(x.shape[0], -1, -1)
+        
+        query_output = self.qformer(
+            query_embeds=query,
+            encoder_hidden_states=x,
+            encoder_attention_mask=atts,
+            return_dict=True,
+        )
+        
+        query_proj = self.norm(self.linear(query_output.last_hidden_state))
+        
+        return query_proj
\ No newline at end of file
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
new file mode 100644
index 0000000..4139d8c
--- /dev/null
+++ b/funasr/models/llm_asr/model.py
@@ -0,0 +1,341 @@
+import logging
+from typing import Union, Dict, List, Tuple, Optional
+
+import time
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import autocast
+
+from funasr.models.scama.utils import sequence_mask
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.ctc.ctc import CTC
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy, compute_accuracy
+# from funasr.models.e2e_asr_common import ErrorCalculator
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+
+
+@tables.register("model_classes", "LLMASR")
+class LLMASR(nn.Module):
+    """ """
+    
+    def __init__(
+        self,
+        specaug: str = None,
+        specaug_conf: dict = None,
+        normalize: str = None,
+        normalize_conf: dict = None,
+        encoder: str = None,
+        encoder_conf: dict = None,
+        decoder: str = None,
+        decoder_conf: dict = None,
+        ctc: str = None,
+        ctc_conf: dict = None,
+        ctc_weight: float = 0.5,
+        llm: str = None,
+        llm_conf: dict = None,
+        adaptor: str = None,
+        adaptor_conf: dict = None,
+        input_size: int = 80,
+        vocab_size: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        # extract_feats_in_collect_stats: bool = True,
+        share_embedding: bool = False,
+        # preencoder: Optional[AbsPreEncoder] = None,
+        # postencoder: Optional[AbsPostEncoder] = None,
+        **kwargs,
+    ):
+        
+        super().__init__()
+        
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+        
+        # audio encoder
+        hub = encoder_conf.get("hub", None)
+        if hub == "funasr":
+            from funasr import AutoModel
+            init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+            model = AutoModel(model=init_param_path, model_revision="v2.0.4")
+            # frontend = model.kwargs.get("frontend")
+            model.model.decoder = None
+            
+            self.audio_encoder = model.model
+            # self.frontend = frontend
+            
+        elif hub == "hf":
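+            # loading the audio encoder from the HuggingFace hub is left unimplemented in this patch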
+            pass
+        else:
+            encoder_class = tables.encoder_classes.get(encoder)
+            encoder = encoder_class(input_size=input_size, **encoder_conf)
+            encoder_output_size = encoder.output_size()
+
+        # llm
+        hub = llm_conf.get("hub", "hf")
+        self.llm = None
+        if hub == "hf":
+            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+
+            init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+            model = AutoModelForCausalLM.from_pretrained(
+                init_param_path,
+                load_in_8bit=None,
+                device_map=None,
+                use_cache=None,
+            )
+            freeze = llm_conf.get("freeze", True)
+            if freeze:
+                for name, param in model.named_parameters():
+                    param.requires_grad = False
+                model.eval()
+            self.llm = model
+        
+        # adaptor
+        adaptor_class = tables.adaptor_classes.get(adaptor)
+        adaptor = adaptor_class(**adaptor_conf)
+        
+        self.adaptor = adaptor
+        
+        
+        self.blank_id = blank_id
+        self.sos = sos if sos is not None else vocab_size - 1
+        self.eos = eos if eos is not None else vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.specaug = specaug
+        self.normalize = normalize
+        self.encoder = encoder
+
+
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+        #
+        # if report_cer or report_wer:
+        #     self.error_calculator = ErrorCalculator(
+        #         token_list, sym_space, sym_blank, report_cer, report_wer
+        #     )
+        #
+        self.error_calculator = None
+
+        self.length_normalized_loss = length_normalized_loss
+        self.beam_search = None
+    
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        labels_ids: torch.Tensor,
+        label_mask: torch.Tensor,
+        audio_mask: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+        
+        batch_size = speech.shape[0]
+        
+        # audio encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)
+        
+        # adaptor
+        encoder_out = self.adaptor(encoder_out)
+
+        if input_ids is not None:
+            input_ids[input_ids == -1] = 0
+            input_ids[input_ids == -100] = 0
+            if hasattr(self.llm.model, "embed_tokens"):
+                inputs_embeds = self.llm.model.embed_tokens(input_ids)
+            elif hasattr(self.llm.model.model, "embed_tokens"):
+                inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
+            else:
+                inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
+
+            if audio_mask is not None:
+                batch_size, token_num, dims = inputs_embeds.shape
+                _, l, _ = encoder_out.shape
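+                # Left-pad the adaptor outputs so they line up with the audio placeholder slots, mix them into the token embeddings via audio_mask, then shift left by one position (dropping the bos slot) and zero-pad the tail.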
+                encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
+                inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
+                inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
+
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
+        loss = model_outputs.loss
+
+
+        stats = {}
+        with torch.no_grad():
+            preds = torch.argmax(model_outputs.logits, -1)
+            acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+            stats["acc"] = acc_att
+
+        stats["loss"] = torch.clone(loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + 1).sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+    
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    
+        audio_mask = kwargs.get("audio_mask", None)
+        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
+
+        batch = {"speech": speech, "speech_lengths": speech_lengths}
+        enc, enc_lens = self.audio_encoder.encode(**batch)
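+        # The audio encoder's predictor resamples the frame-level encoder output down to token-level acoustic embeddings; during training, audio_mask supplies the target number of tokens.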
+        with autocast(False):
+            enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
+            pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(
+                enc, mask=enc_mask, target_label_length=audio_token_lengths,
+            )
+
+        return pre_acoustic_embeds, pre_token_length
+
+
+    def inference(self,
+                  data_in,
+                  data_lengths=None,
+                  key: list = None,
+                  tokenizer=None,
+                  frontend=None,
+                  **kwargs,
+                  ):
+        
+        prompt = kwargs.get("prompt", "Transcribe speech to text.")
+        
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+
+        meta_data = {}
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+                                                            data_type=kwargs.get("data_type", "sound"),
+                                                            tokenizer=tokenizer)
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+        
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+        
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        # adaptor
+        encoder_out = self.adaptor(encoder_out)
+
+        prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
+        prompt_ids = tokenizer.encode(prompt_pre)
+        prompt_length = len(prompt_ids)
+        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
+
+
+        if hasattr(self.llm.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+        elif hasattr(self.llm.model.model, "embed_tokens"):
+            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+        else:
+            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+
+        inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1)  # [prompt, audio]
+        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
+        
+        # model_outputs = self.llm.generate(
+        #     inputs_embeds=inputs_embeds,
+        #     max_length=kwargs.get("max_length", 200),
+        #     max_new_tokens=kwargs.get("max_new_tokens", 200),
+        #     num_beams=kwargs.get("num_beams", 4),
+        #     do_sample=kwargs.get("do_sample", False),
+        #     min_length=kwargs.get("min_length", 1),
+        #     top_p=kwargs.get("top_p", 1.0),
+        #     repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+        #     length_penalty=kwargs.get("length_penalty", 1.0),
+        #     temperature=kwargs.get("temperature", 1.0),
+        #     attention_mask=attention_mask,
+        #     bos_token_id=tokenizer.bos_token_id,
+        #     eos_token_id=tokenizer.eos_token_id,
+        #     pad_token_id=tokenizer.pad_token_id
+        # )
+
+
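+        # Greedy decoding via a single forward pass (per-position argmax) rather than autoregressive generation; the commented-out generate() call above is the alternative path.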
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
+        preds = torch.argmax(model_outputs.logits, -1)
+        text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
+
+        text = text[0].split(': ')[-1]
+        text = text.strip()
+
+        ibest_writer = None
+        if kwargs.get("output_dir") is not None:
+            if not hasattr(self, "writer"):
+                self.writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = self.writer[f"{0 + 1}best_recog"]
+
+        results = []
+        result_i = {"key": key[0], "text": text}
+        results.append(result_i)
+
+        if ibest_writer is not None:
+            ibest_writer["text"][key[0]] = text
+
+        return results, meta_data
+

--
Gitblit v1.9.1