游雁
2024-06-08 d94821bbd6f0c53a86724e2c896df6d062432492
fix bug
1个文件已修改
2个文件已添加
279 ■■■■■ 已修改文件
examples/industrial_data_pretraining/llm_asr/demo_speech2text.py 34 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/llm_asr/infer_speech2text.sh 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/llm_asr/model.py 236 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/llm_asr/demo_speech2text.py
New file
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
from funasr import AutoModel
def main():
    """Transcribe a sample audio URL with Paraformer + VAD + punctuation and print the result."""
    model = AutoModel(
        model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
        vad_kwargs={"max_single_segment_time": 60000},
        punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
        # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
    )

    res = model.generate(
        input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
        cache={},
    )
    print(res)

    # NOTE: feature-level input through AutoFrontend cannot be used currently.
    # Kept for reference only:
    # from funasr import AutoFrontend
    # frontend = AutoFrontend(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
    # fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
    # for batch_idx, fbank_dict in enumerate(fbanks):
    #     res = model.generate(**fbank_dict)
    #     print(res)


if __name__ == "__main__":
    main()
examples/industrial_data_pretraining/llm_asr/infer_speech2text.sh
New file
@@ -0,0 +1,9 @@
# Run LLM-ASR inference with a trained checkpoint.
# All paths default to the original experiment locations and can be overridden
# through environment variables, e.g.:
#   CKPT_DIR=/path/to/exp INPUT=/path/to/data.jsonl DEVICE=cuda:0 bash infer_speech2text.sh
ckpt_dir="${CKPT_DIR:-/nfs/zhifu.gzf/ckpt/llm_asr_nar_exp1}"
init_param="${INIT_PARAM:-${ckpt_dir}/model.pt.ep5}"
input_path="${INPUT:-/Users/zhifu/funasr1.0/test_local/data_tmp/tmp_wav_10.jsonl}"
output_dir="${OUTPUT_DIR:-${ckpt_dir}/inference/aishell2-dev_ios-funasr}"
device="${DEVICE:-cpu}"

python funasr/bin/inference.py \
--config-path="${ckpt_dir}" \
--config-name="config.yaml" \
++init_param="${init_param}" \
++input="${input_path}" \
++output_dir="${output_dir}" \
++device="${device}"
funasr/models/llm_asr/model.py
@@ -18,6 +18,7 @@
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
@tables.register("model_classes", "LLMASR")
@@ -488,8 +489,6 @@
            fbank_fake_len = fbank_fake_lens[batch_idx].item()
            fbank_beg_idx = fbank_beg[batch_idx, 0].item()
            min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx)
            fbank_fake_len = encoder_out_lens[batch_idx].item()
            min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx)
            try:
                inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
                    batch_idx, :min_len, :
@@ -506,6 +505,7 @@
                ]
        labels_ids[labels_ids == -1] = -100
        model_outputs = self.llm(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
        )
@@ -532,6 +532,111 @@
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def data_template(self, data_in):
        system, user, assistant = [], [], []
        for i, item in enumerate(data):
            role = item["role"]
            content = item["content"]
            if role == "system":
                system.append(content)
            elif role == "user":
                user.append(content)
            elif role == "assistant":
                assistant.append(content)
        system = system * len(user)
        contents = {
            "system": system,
            "user": user,
            "assistant": assistant,
        }
        return contents
    def data_load_speech(self, contents: dict, tokenizer, frontend, **kwargs):
        system = contents["system"]
        user = contents["user"]
        assistant = contents["assistant"]
        pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
        input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg = [], [], [], [], [], []
        for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
            source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
            splits = pattern.split(source_input)
            source_ids = []
            fbank_mask_i = []
            fbank_beg_i = []
            fbank_lens_i = []
            for k, sub_str in enumerate(splits):
                if not sub_str.startswith("<|startofspeech|>"):
                    sub_token = tokenizer.encode(sub_str)
                    source_ids += sub_token
                    fbank_mask_i += [0] * len(sub_token)
                else:
                    sub_str = sub_str.replace("<|startofspeech|>", "").replace(
                        "<|endofspeech|>", ""
                    )
                    if sub_str.startswith("!"):
                        try:
                            data_src = load_audio_text_image_video(sub_str[1:], fs=frontend.fs)
                        except Exception as e:
                            logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
                        speech, speech_lengths = extract_fbank(
                            data_src,
                            data_type=kwargs.get("data_type", "sound"),
                            frontend=frontend,
                            is_final=True,
                        )  # speech: [b, T, d]
                        if kwargs.get("permute", True):
                            speech = speech.permute(0, 2, 1)
                        olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
                        olens = 1 + (olens - 3 + 2 * 1) // 2
                        sub_token_len = (olens - 1) // 2 + 1
                        sub_token = [0] * sub_token_len
                        fbank_beg_i = [len(source_ids)]
                        source_ids += sub_token
                        fbank_mask_i += [1] * len(sub_token)
            source_mask = [-100] * len(source_ids)
            target_out = f"{target_out}<|im_end|>"
            target_ids = tokenizer.encode(target_out)
            input_ids += source_ids + target_ids
            labels += source_mask + target_ids
            fbank_mask += fbank_mask_i
            fbank_beg.append(fbank_beg_i)
        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
        attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
        labels = torch.tensor(labels, dtype=torch.int64)  # [: self.max_token_length]
        source_ids = torch.tensor(source_ids, dtype=torch.int64)
        target_ids = torch.tensor(target_ids, dtype=torch.int64)
        fbank = speech[0, :, :]
        fbank_lens = speech_lengths
        fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
        fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
        output = {
            "speech": fbank[None, :, :],
            "speech_lengths": fbank_lens[:, None],
            "fbank_mask": fbank_mask[None, :],
            "fbank_beg": fbank_beg[None,],
            "input_ids": input_ids[None, :],
            "attention_mask": attention_mask[None, :],
            "labels_ids": labels[None, :],
            "source_ids": source_ids[None, :],
            "target_ids": target_ids[None, :],
        }
        return output
    def inference(
        self,
        data_in,
@@ -542,92 +647,54 @@
        **kwargs,
    ):
        prompt = kwargs.get("prompt", "Transcribe speech to text.")
        meta_data = {}
        prompt = kwargs.get("prompt", None)
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
        contents = self.data_template(data_in)
        output = self.data_load_speech(contents, tokenizer, frontend, **kwargs)
        batch = to_device(output, kwargs["device"])
        # audio encoder
        speech = batch["speech"]
        speech_lengths = batch["speech_lengths"][:, 0]
        encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
        # audio_adaptor
        encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
        input_ids = batch["input_ids"]
        source_ids = batch["source_ids"]
        if kwargs.get("tearchforing", False):
            input_ids = source_ids
        input_ids[input_ids < 0] = 0
        inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
        batch_size, token_num, dims = inputs_embeds.shape
        fbank_beg = batch["fbank_beg"]
        for batch_idx in range(batch_size):
            min_len = encoder_out_lens[batch_idx].item()
            fbank_beg_idx = fbank_beg[batch_idx]
            inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
                batch_idx, :min_len, :
            ]
        if not kwargs.get("tearchforing", False):
            generated_ids = self.llm.generate(
                inputs_embeds=inputs_embeds, max_new_tokens=kwargs.get("max_length", 512)
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            )
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # adaptor
        encoder_out = self.audio_adaptor(encoder_out)
        prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
        prompt_ids = tokenizer.encode(prompt_pre)
        prompt_length = len(prompt_ids)
        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
        if hasattr(self.llm.model, "embed_tokens"):
            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
        elif hasattr(self.llm.model.model, "embed_tokens"):
            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
        else:
            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
        inputs_embeds = torch.cat(
            (inputs_embeds[None, :, :], encoder_out), dim=1
        )  # [prompt, audio]
        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(
            kwargs["device"]
        )
        preds = self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_length=kwargs.get("max_length", 200),
            max_new_tokens=kwargs.get("max_new_tokens", 200),
            num_beams=kwargs.get("num_beams", 4),
            do_sample=kwargs.get("do_sample", False),
            min_length=kwargs.get("min_length", 1),
            top_p=kwargs.get("top_p", 1.0),
            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
            length_penalty=kwargs.get("length_penalty", 1.0),
            temperature=kwargs.get("temperature", 1.0),
            attention_mask=attention_mask,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
        text = text[0].split(": ")[-1]
        text = text.strip()
        # preds = torch.argmax(model_outputs.logits, -1)
            generated_ids = [
                output_ids[len(input_id) :]
                for input_id, output_ids in zip(input_ids, generated_ids)
            ]
            response = tokenizer.batch_decode(
                generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
            )[0]
            label = contents["assistant"][0]
        ibest_writer = None
        if kwargs.get("output_dir") is not None:
@@ -636,10 +703,11 @@
            ibest_writer = self.writer[f"{0 + 1}best_recog"]
        results = []
        result_i = {"key": key[0], "text": text}
        result_i = {"key": key[0], "text": response, "label": label}
        results.append(result_i)
        if ibest_writer is not None:
            ibest_writer["text"][key[0]] = text
            ibest_writer["label"][key[0]] = label
        return results, meta_data