游雁
2024-06-06 27256ed429c95ed8868a01f8555610393dd7b3a1
auto frontend
2个文件已修改
3个文件已添加
630 ■■■■■ 已修改文件
funasr/datasets/openai_datasets/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/openai_datasets/datasets.py 216 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/openai_datasets/index_ds.py 95 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/sense_voice_datasets/datasets.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/llm_asr/model.py 318 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/openai_datasets/__init__.py
funasr/datasets/openai_datasets/datasets.py
New file
@@ -0,0 +1,216 @@
import logging
import re
import torch
import random
import traceback
from funasr.register import tables
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
@tables.register("dataset_classes", "OpenAIDataset")
class OpenAIDataset(torch.utils.data.Dataset):
    """
    Chat-style ("messages") speech/text dataset for LLM-based ASR.

    Each index item holds parallel ``system`` / ``user`` / ``assistant`` lists.
    Every turn is rendered with the ``<|im_start|>`` / ``<|im_end|>`` chat
    template. A ``<|startofspeech|>!<path><|endofspeech|>`` span inside a prompt
    is loaded as audio, turned into fbank features, and replaced in the token
    stream by placeholder ids (0) flagged with 1 in ``fbank_mask``.
    """

    def __init__(
        self,
        path,
        index_ds: str = None,
        frontend=None,
        tokenizer=None,
        int_pad_value: int = -1,
        float_pad_value: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        index_ds_class = tables.index_ds_classes.get(index_ds)
        self.index_ds = index_ds_class(path, **kwargs)
        preprocessor_speech = kwargs.get("preprocessor_speech", None)
        if preprocessor_speech:
            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
            preprocessor_speech = preprocessor_speech_class(
                **kwargs.get("preprocessor_speech_conf")
            )
        self.preprocessor_speech = preprocessor_speech
        preprocessor_text = kwargs.get("preprocessor_text", None)
        if preprocessor_text:
            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
        self.preprocessor_text = preprocessor_text
        self.frontend = frontend
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer
        self.int_pad_value = int_pad_value
        self.float_pad_value = float_pad_value
        self.sos = kwargs.get("sos", "<|startoftranscript|>")
        self.eos = kwargs.get("eos", "<|endoftext|>")
        self.batch_size = kwargs.get("batch_size")
        self.batch_type = kwargs.get("batch_type")
        self.prompt_ids_len = 0
        self.retry = kwargs.get("retry", 5)
        # Whisper features are permuted to (d, T) per sample below, so the time axis moves.
        self.permute = False
        from funasr.frontends.whisper_frontend import WhisperFrontend

        if isinstance(self.frontend, WhisperFrontend):
            self.permute = True
        self.pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")

    def get_source_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_source_len(item)

    def get_target_len(self, index):
        item = self.index_ds[index]
        return self.index_ds.get_target_len(item)

    def __len__(self):
        return len(self.index_ds)

    def __getitem__(self, index):
        """Build one example, retrying with random indices when an item is unusable.

        Returns a dict with ``speech``, ``speech_lengths``, ``fbank_mask``,
        ``fbank_beg``, ``input_ids``, ``attention_mask`` and ``labels_ids``, or
        ``None`` if all ``self.retry`` attempts were rejected (the collator skips
        ``None`` samples).
        """
        output = None
        for idx in range(self.retry):
            if idx == 0:
                index_cur = index
            else:
                index_cur = torch.randint(0, len(self.index_ds), ()).item()
            output = self._build_example(self.index_ds[index_cur])
            if output is not None:
                break
        return output

    def _build_example(self, item):
        """Tokenize one chat item and load its audio; return ``None`` for a bad case.

        An item is rejected when a speech segment is longer than ``batch_size``
        frames or when no loadable speech segment is present.
        NOTE(review): if an item contains several speech segments, only the last
        one's features are returned while ``fbank_beg`` lists every placeholder
        offset — presumably single-speech items are the expected input.
        """
        input_ids, labels, fbank_mask, fbank_beg = [], [], [], []
        speech, speech_lengths = None, None
        for system_prompt, user_prompt, target_out in zip(
            item["system"], item["user"], item["assistant"]
        ):
            source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
            source_ids, fbank_mask_i, fbank_beg_i = [], [], []
            for sub_str in self.pattern.split(source_input):
                if not sub_str.startswith("<|startofspeech|>"):
                    sub_token = self.tokenizer.encode(sub_str)
                    source_ids += sub_token
                    fbank_mask_i += [0] * len(sub_token)
                    continue
                sub_str = sub_str.replace("<|startofspeech|>", "").replace("<|endofspeech|>", "")
                if not sub_str.startswith("!"):
                    # Only "!<path>" speech references are loaded; other spans are dropped.
                    continue
                data_src = load_audio_text_image_video(sub_str[1:], fs=self.fs)
                speech, speech_lengths = extract_fbank(
                    data_src,
                    data_type=self.data_type,
                    frontend=self.frontend,
                    is_final=True,
                )  # speech: [b, T, d]
                if self.permute:
                    speech = speech.permute(0, 2, 1)
                fbank_len = speech_lengths[0].item()
                if self.batch_size is not None and fbank_len > self.batch_size:
                    # Too long to ever fit a batch: reject the item so the caller retries.
                    return None
                # Output length of two stride-2 convs (kernel 3, padding 1) followed by a
                # further ceil(/2); presumably mirrors encoder + adaptor downsampling — TODO confirm.
                olens = 1 + (fbank_len - 3 + 2 * 1) // 2
                olens = 1 + (olens - 3 + 2 * 1) // 2
                sub_token_len = (olens - 1) // 2 + 1
                # Offset within the whole conversation, not just the current turn.
                fbank_beg_i = [len(input_ids) + len(source_ids)]
                source_ids += [0] * sub_token_len
                fbank_mask_i += [1] * sub_token_len
            target_ids = self.tokenizer.encode(f"{target_out}<|im_end|>")
            input_ids += source_ids + target_ids
            # Prompt tokens are not supervised (-100 is the ignore label).
            labels += [-100] * len(source_ids) + target_ids
            fbank_mask += fbank_mask_i
            if fbank_beg_i:
                # Skipping speech-less turns keeps the nested list rectangular for torch.tensor.
                fbank_beg.append(fbank_beg_i)
        if speech is None:
            # Nothing for the audio encoder to consume: treat as a bad case.
            return None
        return {
            "speech": speech[0, :, :],
            "speech_lengths": speech_lengths,
            "fbank_mask": torch.tensor(fbank_mask, dtype=torch.float32),
            "fbank_beg": torch.tensor(fbank_beg, dtype=torch.int32),
            "input_ids": torch.tensor(input_ids, dtype=torch.int64),
            # Per-token mask (all ones); padding is added by the collator.
            "attention_mask": torch.ones(len(input_ids), dtype=torch.int32),
            "labels_ids": torch.tensor(labels, dtype=torch.int64),
        }

    def collator(self, samples: list = None):
        """Pad a list of examples into a batch dict, skipping ``None`` samples.

        ``labels_ids`` is padded with -100 (the ignore label used for prompt tokens)
        and ``attention_mask`` with 0. Other integer tensors use ``int_pad_value``
        and float tensors use ``float_pad_value``.
        """
        outputs = {}
        for sample in samples or []:
            if sample is None:
                continue
            for key, value in sample.items():
                outputs.setdefault(key, []).append(value)
        if not outputs:
            return outputs

        key_pad_values = {"labels_ids": -100, "attention_mask": 0}
        for key, data_list in outputs.items():
            if not isinstance(data_list[0], torch.Tensor):
                continue
            if key in key_pad_values:
                pad_value = key_pad_values[key]
            elif data_list[0].dtype in (torch.int64, torch.int32):
                pad_value = self.int_pad_value
            else:
                pad_value = self.float_pad_value
            outputs[key] = torch.nn.utils.rnn.pad_sequence(
                data_list, batch_first=True, padding_value=pad_value
            )
        if self.batch_type != "example":
            for i in range(10):
                outputs = self._filter_badcase(outputs, i=i)
        return outputs

    def _filter_badcase(self, outputs, i=0):
        """Keep every other sample when batch * time exceeds 1.25x ``batch_size``.

        The starting offset (0 or 1) is random. Afterwards, ``speech`` is trimmed
        along its time axis to the longest remaining ``speech_lengths``.
        """
        if self.batch_size is None or "speech" not in outputs:
            return outputs
        speech = outputs["speech"]
        time_axis = 2 if self.permute else 1
        b, t = speech.shape[0], speech.shape[time_axis]
        if b * t <= self.batch_size * 1.25:
            return outputs
        beg = torch.randint(0, 2, ()).item()
        if b < 2:
            beg = 0
        logging.info(
            f"Warning, b * t: {b * t} > {self.batch_size}, drop half data {i}th, beg:{beg}"
        )
        for key in outputs:
            outputs[key] = outputs[key][beg : beg + b : 2]
        speech_lengths_max = outputs["speech_lengths"].max().item()
        if self.permute:
            outputs["speech"] = outputs["speech"][:, :, :speech_lengths_max]
        else:
            outputs["speech"] = outputs["speech"][:, :speech_lengths_max, :]
        # Optional text-side tensors: trimmed only when a producer actually emits them.
        for len_key, data_key in (("text_lengths", "text"), ("target_mask_lengths", "target_mask")):
            if len_key in outputs and data_key in outputs:
                outputs[data_key] = outputs[data_key][:, : outputs[len_key].max().item()]
        return outputs
funasr/datasets/openai_datasets/index_ds.py
New file
@@ -0,0 +1,95 @@
import os
import json
import torch
import logging
import librosa
import random
import torch.distributed as dist
from funasr.register import tables
@tables.register("index_ds_classes", "OpenAIIndexDSJsonl")
class OpenAIIndexDSJsonl(torch.utils.data.Dataset):  # torch.utils.data.Dataset
    """In-memory index of chat samples read from ``{"messages": [...]}`` JSONL files.

    ``path`` is either a single ``.jsonl``/``.json`` file or a text file listing
    one JSONL path per line. For a list file in training mode, the list is
    sliced by ``data_split_num``/``data_split_i``. Each sample becomes
    ``{"system": [...], "user": [...], "assistant": [...]}``.
    """

    def __init__(self, path: str, **kwargs):
        super().__init__()
        self.max_source_length = kwargs.get("max_source_length", 2048)
        self.min_source_length = kwargs.get("min_source_length", 0)
        self.max_target_length = kwargs.get("max_target_length", 2048)
        self.min_target_length = kwargs.get("min_target_length", 0)
        self.max_token_length = kwargs.get("max_token_length", 2200)
        is_training = kwargs.get("is_training", True)
        if not path.endswith((".jsonl", ".json")):
            # jsonl list file
            data_split_num = kwargs.get("data_split_num", 1)
            data_split_i = kwargs.get("data_split_i", 0)
            if not is_training:
                data_split_num = 1
                data_split_i = 0
            with open(path, encoding="utf-8") as fin:
                # Ignore blank lines so they neither become open("") nor skew the split.
                file_list_all = [line.strip() for line in fin if line.strip()]
            num_per_slice = (len(file_list_all) - 1) // data_split_num + 1  # ceil division
            file_list = file_list_all[
                data_split_i * num_per_slice : (data_split_i + 1) * num_per_slice
            ]
            logging.info(
                f"is_training: {is_training}, data_split_num: {data_split_num}, data_split_i: {data_split_i}, \nfile_list: {file_list}, \nfile_list_all: {file_list_all}"
            )
        else:
            file_list = [path]
        contents = []
        for file_json in file_list:
            with open(file_json.strip(), encoding="utf-8") as fin:
                for line in fin:
                    line = line.strip()
                    if not line:
                        # Trailing/blank lines would make json.loads fail.
                        continue
                    contents.append(self._parse_messages(json.loads(line)["messages"]))
        self.contents = contents
        logging.info("total_num of samplers: {}, {}".format(len(self.contents), path))

    @staticmethod
    def _parse_messages(messages):
        """Group message contents by role; other roles are ignored.

        The system list is repeated once per user turn so the three lists can be
        zipped per turn. NOTE(review): with no system message, the repeated list
        is empty and zipping yields no turns — confirm whether that is intended.
        """
        system, user, assistant = [], [], []
        buckets = {"system": system, "user": user, "assistant": assistant}
        for item in messages:
            bucket = buckets.get(item["role"])
            if bucket is not None:
                bucket.append(item["content"])
        return {"system": system * len(user), "user": user, "assistant": assistant}

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        data = self.contents[index]
        return data

    def get_source_len(self, data_dict):
        # Counts prompt entries (system + user turns), not tokens.
        return len(data_dict["system"]) + len(data_dict["user"])

    def get_target_len(self, data_dict):
        # Counts assistant turns, not tokens.
        return len(data_dict["assistant"])
if __name__ == "__main__":
    # Manual smoke test: parse a local JSONL file and dump the parsed samples.
    demo_path = "/Users/zhifu/funasr1.0/test_local/data_tmp/tmp_wav_10.jsonl"
    demo_index = OpenAIIndexDSJsonl(path=demo_path)
    print(demo_index.contents)
funasr/datasets/sense_voice_datasets/datasets.py
@@ -1,5 +1,6 @@
import logging
import re
import torch
import random
import traceback
funasr/models/llm_asr/model.py
@@ -341,3 +341,321 @@
            ibest_writer["text"][key[0]] = text
        return results, meta_data
@tables.register("model_classes", "LLMASR2")
class LLMASR2(nn.Module):
    """LLM-based ASR: audio encoder -> audio adaptor -> decoder-only LLM.

    During training, adaptor outputs are written into the LLM input embeddings
    at the placeholder positions described by ``fbank_beg`` / ``fbank_mask``.
    These are presumably produced by ``OpenAIDataset``; verify against the
    configured dataset.
    """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        audio_encoder: str = None,
        audio_encoder_conf: dict = None,
        audio_adaptor: str = None,
        audio_adaptor_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        llm: str = None,
        llm_conf: dict = None,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        **kwargs,
    ):
        super().__init__()
        # Work on copies so the caller's config dicts are never mutated, and so a
        # missing (None) config behaves like an empty one.
        audio_encoder_conf = dict(audio_encoder_conf or {})
        audio_adaptor_conf = dict(audio_adaptor_conf or {})
        llm_conf = dict(llm_conf or {})

        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)

        # audio encoder
        hub = audio_encoder_conf.get("hub", None)
        if hub == "ms":
            from funasr import AutoModel

            model = AutoModel(model=audio_encoder, model_revision="master")
            audio_encoder_output_size = model.model.encoder_output_size
            audio_encoder = model.model.model.encoder
        elif hub == "hf":
            raise NotImplementedError(
                'LLMASR2: audio_encoder_conf["hub"] == "hf" is not supported yet'
            )
        else:
            encoder_class = tables.encoder_classes.get(audio_encoder)
            audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
            audio_encoder_output_size = audio_encoder.output_size()
        freeze = audio_encoder_conf.get("freeze", True)
        if freeze:
            for name, param in audio_encoder.named_parameters():
                param.requires_grad = False
            audio_encoder.eval()
        self.audio_encoder = audio_encoder

        # llm: forward() and inference() both require it, so it must be loaded here.
        hub = llm_conf.get("hub", "hf")
        self.llm = None
        if hub == "hf":
            from transformers import AutoModelForCausalLM

            init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
            model = AutoModelForCausalLM.from_pretrained(
                init_param_path,
                load_in_8bit=None,
                device_map=None,
                use_cache=None,
            )
            freeze = llm_conf.get("freeze", True)
            if freeze:
                for name, param in model.named_parameters():
                    param.requires_grad = False
                model.eval()
            self.llm = model

        # adaptor
        adaptor_class = tables.adaptor_classes.get(audio_adaptor)
        audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
        audio_adaptor = adaptor_class(**audio_adaptor_conf)
        self.audio_adaptor = audio_adaptor

        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        # NOTE(review): specaug/normalize are stored but not applied in encode().
        self.specaug = specaug
        self.normalize = normalize
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        self.error_calculator = None
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None

    def _embed_tokens(self, input_ids):
        """Look up LLM input embeddings, tolerating extra levels of ``.model`` nesting."""
        if self.llm is None:
            raise RuntimeError('LLMASR2: no LLM loaded (llm_conf["hub"] must be "hf")')
        if hasattr(self.llm.model, "embed_tokens"):
            return self.llm.model.embed_tokens(input_ids)
        if hasattr(self.llm.model.model, "embed_tokens"):
            return self.llm.model.model.embed_tokens(input_ids)
        return self.llm.model.model.model.embed_tokens(input_ids)

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels_ids: torch.Tensor,
        fbank_beg: torch.Tensor,
        fbank_mask: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + adaptor + LLM, returning ``(loss, stats, weight)``.

        Args:
                speech: (Batch, Length, ...) features
                speech_lengths: (Batch,) or (Batch, 1)
                input_ids: (Batch, Tokens) ids with 0 placeholders at speech positions
                attention_mask: (Batch, Tokens); negative pad values are treated as 0
                labels_ids: (Batch, Tokens); negative values are ignored in the loss
                fbank_beg: per-sample start offset of the speech placeholders
                fbank_mask: (Batch, Tokens), 1 at speech placeholder positions
        """
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # audio encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # audio_adaptor
        encoder_out = self.audio_adaptor(encoder_out)

        # Negative ids (padding / ignore labels) are not valid vocabulary entries; map
        # them to 0 without mutating the caller's tensor.
        input_ids = input_ids.masked_fill(input_ids < 0, 0)
        # The LLM loss ignores only -100, so any other negative pad value is remapped.
        labels_ids = labels_ids.masked_fill(labels_ids < 0, -100)
        # Padded mask positions may carry a negative pad value; they must be masked out.
        attention_mask = attention_mask.clamp(min=0)

        inputs_embeds = self._embed_tokens(input_ids)

        batch_size, token_num, dims = inputs_embeds.shape
        _, l, _ = encoder_out.shape
        for batch_idx in range(batch_size):
            fbank_beg_idx = int(fbank_beg[batch_idx].reshape(-1)[0].item())
            # Only fill the placeholders reserved for this sample; writing the full
            # padded encoder length would overwrite the following text tokens.
            speech_token_num = int(fbank_mask[batch_idx].sum().item())
            fill_len = min(speech_token_num, l, token_num - fbank_beg_idx)
            if fill_len <= 0:
                continue
            inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + fill_len, :] = encoder_out[
                batch_idx, :fill_len, :
            ]

        model_outputs = self.llm(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
        )
        loss = model_outputs.loss

        stats = {}
        with torch.no_grad():
            preds = torch.argmax(model_outputs.logits, -1)
            acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
            stats["acc"] = acc_att

        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            # Weight by the number of supervised (non-ignored) target tokens.
            batch_size = int((labels_ids != -100).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Run the audio encoder; falls back to ``speech_lengths`` if it returns no lengths."""
        speech = speech.permute(0, 2, 1)
        res = self.audio_encoder(speech)
        if isinstance(res, (list, tuple)):
            encoder_out, encoder_out_lens = res[0], res[1]
        else:
            encoder_out, encoder_out_lens = res, speech_lengths
        return encoder_out, encoder_out_lens

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Decode a single utterance with ``self.llm.generate``.

        The audio embedding is appended after a text prompt. Only
        ``batch_size == 1`` is supported.
        """
        prompt = kwargs.get("prompt", "Transcribe speech to text.")

        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # Must be a tensor: it is moved to the device below.
                speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64)
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            )

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # adaptor
        encoder_out = self.audio_adaptor(encoder_out)

        prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
        prompt_ids = tokenizer.encode(prompt_pre)
        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])

        inputs_embeds = self._embed_tokens(prompt_ids)

        inputs_embeds = torch.cat(
            (inputs_embeds[None, :, :], encoder_out), dim=1
        )  # [prompt, audio]
        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(
            kwargs["device"]
        )

        preds = self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_length=kwargs.get("max_length", 200),
            max_new_tokens=kwargs.get("max_new_tokens", 200),
            num_beams=kwargs.get("num_beams", 4),
            do_sample=kwargs.get("do_sample", False),
            min_length=kwargs.get("min_length", 1),
            top_p=kwargs.get("top_p", 1.0),
            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
            length_penalty=kwargs.get("length_penalty", 1.0),
            temperature=kwargs.get("temperature", 1.0),
            attention_mask=attention_mask,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

        text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
        # Keep only what follows the last ": " (strips any echoed prompt prefix).
        text = text[0].split(": ")[-1]
        text = text.strip()

        ibest_writer = None
        if kwargs.get("output_dir") is not None:
            if not hasattr(self, "writer"):
                self.writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = self.writer[f"{0 + 1}best_recog"]

        results = []
        result_i = {"key": key[0], "text": text}
        results.append(result_i)

        if ibest_writer is not None:
            ibest_writer["text"][key[0]] = text

        return results, meta_data