zhifu gao
2024-06-28 8c87a9d8a7c2f136053476670a9a83980f142aec
funasr/models/llm_asr/model.py
@@ -1145,6 +1145,7 @@
            fake_token_len_i = 0
            fbank_beg_i = -1
            fbank_lens_i = []
            speech, speech_lengths = [], []
            for k, sub_str in enumerate(splits):
                if not sub_str.startswith("<|startofspeech|>"):
                    sub_token = tokenizer.encode(sub_str)
@@ -1155,9 +1156,12 @@
                        "<|endofspeech|>", ""
                    )
                    if sub_str.startswith("!"):
                        sub_str = sub_str[1:]
                        if sub_str.startswith("!"):  # !!bytes
                            sub_str = eval(sub_str[1:])
                        try:
                            time1 = time.perf_counter()
                            data_src = load_audio_text_image_video(sub_str[1:], fs=frontend.fs)
                            data_src = load_audio_text_image_video(sub_str, fs=frontend.fs)
                            time2 = time.perf_counter()
                            meta_data["load_data"] = f"{time2 - time1:0.3f}"
                        except Exception as e:
@@ -1203,9 +1207,10 @@
            input_source_ids = input_ids + source_ids
            input_ids += source_ids + target_ids
            labels += source_mask + target_ids
            fbank.append(speech[0, :, :])
            fbank_mask += fbank_mask_i
            fbank_lens.append(speech_lengths)
            if len(speech) > 0:
                fbank.append(speech[0, :, :])
                fbank_lens.append(speech_lengths)
        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
        attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
@@ -1219,10 +1224,14 @@
        source_ids = torch.tensor(input_source_ids, dtype=torch.int64)
        target_ids = torch.tensor(target_ids, dtype=torch.int64)
        speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
        speech_lengths = torch.nn.utils.rnn.pad_sequence(
            fbank_lens, batch_first=True, padding_value=-1
        )
        if len(fbank) > 0:
            speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
            speech_lengths = torch.nn.utils.rnn.pad_sequence(
                fbank_lens, batch_first=True, padding_value=-1
            )
        else:
            speech = []
            speech_lengths = []
        output = {
            "speech": speech,
            "speech_lengths": speech_lengths,
@@ -1238,7 +1247,8 @@
        return output
    def inference(
    def inference_prepare(
        self,
        data_in,
        data_lengths=None,
@@ -1260,17 +1270,18 @@
        # audio encoder
        speech = batch["speech"]
        speech_lengths = batch["speech_lengths"][:, 0]
        # fp16
        if kwargs.get("fp16", False):
            speech = speech.to(torch.float16)
        elif kwargs.get("bf16", False):
            speech = speech.to(torch.bfloat16)
        # audio encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if len(speech) > 0:
            speech_lengths = batch["speech_lengths"][:, 0]
            # fp16
            if kwargs.get("fp16", False):
                speech = speech.to(torch.float16)
            elif kwargs.get("bf16", False):
                speech = speech.to(torch.bfloat16)
            # audio encoder
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # audio_adaptor
        encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
            # audio_adaptor
            encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
        input_ids = batch["input_ids"]
        source_ids = batch["source_ids"]
@@ -1316,6 +1327,22 @@
                        ] = speech_token
                    speech_idx += 1
        return inputs_embeds, contents, batch, source_ids, meta_data
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
            data_in, data_lengths, key, tokenizer, frontend, **kwargs
        )
        llm_dtype = kwargs.get("llm_dtype", "fp32")
        if llm_dtype == "fp32":