zhifu gao
2024-06-20 e65b1f701abca03bf3a1b5fbb200392aabd38c22
funasr/models/llm_asr/model.py
@@ -988,9 +988,9 @@
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        import pdb
        pdb.set_trace()
        # import pdb
        #
        # pdb.set_trace()
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
@@ -1011,6 +1011,7 @@
        fake_token_len = kwargs.get("fake_token_len")
        fake_token_len[fake_token_len < 0] = 0
        fbank_beg[fbank_beg < 0] = 0
        speech_idx = 0
        for batch_idx in range(batch_size):
@@ -1025,12 +1026,15 @@
                            batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
                        ] = speech_token
                    except Exception as e:
                        #
                        logging.error(f"{str(e)}, {traceback.format_exc()}")
                        logging.info(
                            f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens[speech_idx].item()}"
                            f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
                        )
                        # import pdb;
                        # pdb.set_trace()
                        speech_token_len = encoder_out_lens[speech_idx].item()
                        speech_token = encoder_out[speech_idx, turn_id, :speech_token_len, :]
                        speech_token = encoder_out[speech_idx, :speech_token_len, :]
                        inputs_embeds[
                            batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
                        ] = speech_token
@@ -1064,6 +1068,12 @@
        stats["batch_size_x_tokens"] = token_num * batch_size
        stats["batch_size_real_tokens"] = attention_mask.sum().item()
        stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
        dialog_turns = (fbank_beg > 0).sum(-1)
        dialog_turns_max = torch.max(dialog_turns).int().item()
        dialog_turns_avg = dialog_turns.sum().item() / batch_size
        stats["dialog_turns_max"] = dialog_turns_max
        stats["dialog_turns_avg"] = dialog_turns_avg
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
@@ -1105,8 +1115,8 @@
        user = contents["user"]
        assistant = contents["assistant"]
        pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
        input_ids, labels, source_ids, target_ids, fbank, fbank_lens, fbank_mask, fbank_beg = (
            [],
        input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
            [],
            [],
            [],
@@ -1115,21 +1125,30 @@
            [],
            [],
        )
        input_source_ids = []
        for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
            if i >= kwargs.get("multiturn_num_max", 5):
                break
            if len(input_ids) > kwargs.get("max_token_length", 1500):
            source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
                break
            if i == 0:
                source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
            else:
                source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
            splits = pattern.split(source_input)
            source_ids_i = []
            source_ids = []
            fbank_i = []
            fbank_mask_i = []
            fbank_beg_i = []
            fake_token_len_i = 0
            fbank_beg_i = -1
            fbank_lens_i = []
            # target_ids_i = []
            for k, sub_str in enumerate(splits):
                if not sub_str.startswith("<|startofspeech|>"):
                    sub_token = tokenizer.encode(sub_str)
                    source_ids_i += sub_token
                    source_ids += sub_token
                    fbank_mask_i += [0] * len(sub_token)
                else:
                    sub_str = sub_str.replace("<|startofspeech|>", "").replace(
@@ -1162,42 +1181,57 @@
                        if kwargs.get("permute", True):
                            speech = speech.permute(0, 2, 1)
                        if speech_lengths > kwargs.get("max_source_length", 5500):
                            # logging.info(
                            #     f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}"
                            # )
                            badcase_flag = True
                        olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
                        olens = 1 + (olens - 3 + 2 * 1) // 2
                        sub_token_len = (olens - 1) // 2 + 1
                        sub_token = [0] * sub_token_len
                        fbank_beg_i = [len(source_ids_i)]
                        source_ids_i += sub_token
                        fbank_mask_i += [1] * len(sub_token)
                        fake_token_len_i = (olens - 1) // 2 + 1
                        fake_token = [0] * fake_token_len_i
                        fbank_beg_i = len(source_ids)
                        source_ids += fake_token
                        fbank_mask_i += [1] * len(fake_token)
            source_mask = [-100] * len(source_ids_i)
            fbank_beg += [fbank_beg_i + len(input_ids)]
            fake_token_len += [fake_token_len_i]
            source_mask = [-100] * len(source_ids)
            target_out = f"{target_out}<|im_end|>"
            target_ids = tokenizer.encode(target_out)
            input_ids += source_ids_i + target_ids
            input_source_ids = input_ids + source_ids
            input_ids += source_ids + target_ids
            labels += source_mask + target_ids
            fbank.append(speech[0, :, :])
            fbank_mask += fbank_mask_i
            fbank_beg.append(fbank_beg_i)
            fbank_lens.append(speech_lengths)
        input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
        attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
        labels = torch.tensor(labels, dtype=torch.int64)  # [: self.max_token_length]
        source_ids = torch.tensor(source_ids_i, dtype=torch.int64)
        target_ids = torch.tensor(target_ids, dtype=torch.int64)
        fbank = speech[0, :, :]
        fbank_lens = speech_lengths
        # fbank = speech[0, :, :]
        # fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32)
        fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
        fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
        fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
        source_ids = torch.tensor(input_source_ids, dtype=torch.int64)
        target_ids = torch.tensor(target_ids, dtype=torch.int64)
        speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
        speech_lengths = torch.nn.utils.rnn.pad_sequence(
            fbank_lens, batch_first=True, padding_value=-1
        )
        output = {
            "speech": fbank[None, :, :],
            "speech_lengths": fbank_lens[:, None],
            "speech": speech,
            "speech_lengths": speech_lengths,
            "fbank_mask": fbank_mask[None, :],
            "fbank_beg": fbank_beg[None,],
            "input_ids": input_ids[None, :],
            "attention_mask": attention_mask[None, :],
            "labels_ids": labels[None, :],
            "fake_token_len": fake_token_len[None, :],
            "input_ids": input_ids[None,],
            "attention_mask": attention_mask[None,],
            "labels_ids": labels,
            "source_ids": source_ids[None, :],
            "target_ids": target_ids[None, :],
        }
@@ -1240,20 +1274,48 @@
        input_ids = batch["input_ids"]
        source_ids = batch["source_ids"]
        fbank_beg = batch["fbank_beg"]
        fake_token_len = batch["fake_token_len"]
        if not kwargs.get("tearchforing", False):
            input_ids = source_ids
        input_ids[input_ids < 0] = 0
        inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
        batch_size, token_num, dims = inputs_embeds.shape
        fbank_beg = batch["fbank_beg"]
        fake_token_len[fake_token_len < 0] = 0
        fbank_beg[fbank_beg < 0] = 0
        speech_idx = 0
        for batch_idx in range(batch_size):
            min_len = encoder_out_lens[batch_idx].item()
            fbank_beg_idx = fbank_beg[batch_idx]
            inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
                batch_idx, :min_len, :
            ]
            for turn_id in range(fbank_beg.shape[1]):
                fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
                if fbank_beg_idx > 0:
                    speech_token_len = fake_token_len[batch_idx, turn_id]
                    speech_token = encoder_out[speech_idx, :speech_token_len, :]
                    try:
                        inputs_embeds[
                            batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
                        ] = speech_token
                    except Exception as e:
                        #
                        logging.error(f"{str(e)}, {traceback.format_exc()}")
                        logging.info(
                            f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
                        )
                        # import pdb;
                        # pdb.set_trace()
                        speech_token_len = encoder_out_lens[speech_idx].item()
                        speech_token = encoder_out[speech_idx, :speech_token_len, :]
                        inputs_embeds[
                            batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
                        ] = speech_token
                    speech_idx += 1
        llm_dtype = kwargs.get("llm_dtype", "fp32")
        if llm_dtype == "fp32":
@@ -1263,7 +1325,7 @@
        with torch.cuda.amp.autocast(
            enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
        ):
            label = contents["assistant"][0]
            label = contents["assistant"][-1]
            self.llm = self.llm.to(dtype_map[llm_dtype])
            inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
@@ -1313,8 +1375,8 @@
        results.append(result_i)
        if ibest_writer is not None:
            ibest_writer["text"][key[0]] = response
            ibest_writer["label"][key[0]] = label
            ibest_writer["text"][key[0]] = response.replace("\n", " ")
            ibest_writer["label"][key[0]] = label.replace("\n", " ")
            ibest_writer["text_tn"][key[0]] = response_clean
        return results, meta_data