From 3df109adfccedeb134dea4ba2ea9a2da89872048 Mon Sep 17 00:00:00 2001
From: Isuxiz Slidder <48672727+Isuxiz@users.noreply.github.com>
Date: Mon, 31 Mar 2025 17:51:52 +0800
Subject: [PATCH] Update model.py to fix "IndexError: index 1 is out of bounds for dimension 1 with size 0" (#2454)

---
 funasr/models/llm_asr/model.py |  236 ++++++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 161 insertions(+), 75 deletions(-)

diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 1e3515f..fa1d2c3 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -168,8 +168,6 @@
                 text: (Batch, Length)
                 text_lengths: (Batch,)
         """
-        # import pdb;
-        # pdb.set_trace()
         if len(text_lengths.size()) > 1:
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
@@ -814,7 +812,7 @@
             ibest_writer = self.writer[f"{0 + 1}best_recog"]
 
         results = []
-        response_clean = re.sub("[^\w\s\u3000\u4e00-\u9fff]+", "", response)
+        response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
         result_i = {"key": key[0], "text": response, "text_tn": response_clean, "label": label}
         if loss is not None:
             result_i["loss"] = loss
@@ -990,12 +988,14 @@
                 text: (Batch, Length)
                 text_lengths: (Batch,)
         """
-        # import pdb;
+        # import pdb
+        #
         # pdb.set_trace()
         if len(speech_lengths.size()) > 1:
             speech_lengths = speech_lengths[:, 0]
 
-        batch_size, frames, _ = speech.shape
+        batch_size_speech, frames, _ = speech.shape
+        batch_size, token_num = input_ids.shape
 
         with torch.cuda.amp.autocast(enabled=False):
             # audio encoder
@@ -1008,38 +1008,38 @@
         inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
 
         batch_size, token_num, dims = inputs_embeds.shape
-        fbank_mask[fbank_mask < 0] = 0
-        fbank_fake_lens = fbank_mask.sum(-1).to(torch.int32)
-        # _, l, _ = encoder_out.shape
         fake_token_len = kwargs.get("fake_token_len")
         fake_token_len[fake_token_len < 0] = 0
         fbank_beg[fbank_beg < 0] = 0
+
+        speech_idx = 0
         for batch_idx in range(batch_size):
 
             for turn_id in range(fbank_beg.shape[1]):
                 fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
-                if fbank_beg[turn_id] > 0:
+                if fbank_beg_idx > 0:
                     speech_token_len = fake_token_len[batch_idx, turn_id]
-                    speech_token = encoder_out[batch_idx + turn_id, turn_id, :speech_token_len, :]
+                    speech_token = encoder_out[speech_idx, :speech_token_len, :]
 
-            fbank_fake_len = fbank_fake_lens[batch_idx].item()
-            fbank_beg_idx = fbank_beg[batch_idx, 0].item()
-            min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx)
+                    try:
+                        inputs_embeds[
+                            batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
+                        ] = speech_token
+                    except Exception as e:
+                        #
+                        logging.error(f"{str(e)}, {traceback.format_exc()}")
+                        logging.info(
+                            f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
+                        )
+                        # import pdb;
+                        # pdb.set_trace()
+                        speech_token_len = encoder_out_lens[speech_idx].item()
+                        speech_token = encoder_out[speech_idx, :speech_token_len, :]
+                        inputs_embeds[
+                            batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
+                        ] = speech_token
 
-            try:
-                inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
-                    batch_idx, :min_len, :
-                ]
-            except Exception as e:
-                logging.error(f"{str(e)}, {traceback.format_exc()}")
-                logging.info(
-                    f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, min_len: {min_len}, fbank_fake_len: {fbank_fake_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens[batch_idx].item()}"
-                )
-                fbank_fake_len = encoder_out_lens[batch_idx].item()
-                min_len = min(fbank_fake_len, min_len)
-                inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
-                    batch_idx, :min_len, :
-                ]
+                    speech_idx += 1
 
         with torch.cuda.amp.autocast(
             enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
@@ -1061,12 +1061,19 @@
 
         stats["loss"] = torch.clone(loss.detach())
         stats["batch_size"] = batch_size
-        stats["batch_size_x_frames"] = frames * batch_size
+        stats["batch_size_speech"] = batch_size_speech
+        stats["batch_size_x_frames"] = frames * batch_size_speech
         stats["batch_size_real_frames"] = speech_lengths.sum().item()
         stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
         stats["batch_size_x_tokens"] = token_num * batch_size
         stats["batch_size_real_tokens"] = attention_mask.sum().item()
         stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
+
+        dialog_turns = (fbank_beg > 0).sum(-1)
+        dialog_turns_max = torch.max(dialog_turns).int().item()
+        dialog_turns_avg = dialog_turns.sum().item() / batch_size
+        stats["dialog_turns_max"] = dialog_turns_max
+        stats["dialog_turns_avg"] = dialog_turns_avg
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
@@ -1108,8 +1115,8 @@
         user = contents["user"]
         assistant = contents["assistant"]
         pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
-        input_ids, labels, source_ids, target_ids, fbank, fbank_lens, fbank_mask, fbank_beg = (
-            [],
+
+        input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
             [],
             [],
             [],
@@ -1118,30 +1125,43 @@
             [],
             [],
         )
-
+        input_source_ids = []
         for i, (system_prompt, user_prompt, target_out) in enumerate(zip(system, user, assistant)):
+            if i >= kwargs.get("multiturn_num_max", 5):
+                break
+            if len(input_ids) > kwargs.get("max_token_length", 1500):
 
-            source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+                break
+
+            if i == 0:
+                source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+            else:
+                source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
 
             splits = pattern.split(source_input)
-            source_ids_i = []
+            source_ids = []
+            fbank_i = []
             fbank_mask_i = []
-            fbank_beg_i = []
+            fake_token_len_i = 0
+            fbank_beg_i = -1
             fbank_lens_i = []
-            # target_ids_i = []
+            speech, speech_lengths = [], []
             for k, sub_str in enumerate(splits):
                 if not sub_str.startswith("<|startofspeech|>"):
                     sub_token = tokenizer.encode(sub_str)
-                    source_ids_i += sub_token
+                    source_ids += sub_token
                     fbank_mask_i += [0] * len(sub_token)
                 else:
                     sub_str = sub_str.replace("<|startofspeech|>", "").replace(
                         "<|endofspeech|>", ""
                     )
                     if sub_str.startswith("!"):
+                        sub_str = sub_str[1:]
+                        if sub_str.startswith("!"):  # !!bytes
+                            sub_str = eval(sub_str[1:])
                         try:
                             time1 = time.perf_counter()
-                            data_src = load_audio_text_image_video(sub_str[1:], fs=frontend.fs)
+                            data_src = load_audio_text_image_video(sub_str, fs=frontend.fs)
                             time2 = time.perf_counter()
                             meta_data["load_data"] = f"{time2 - time1:0.3f}"
                         except Exception as e:
@@ -1165,49 +1185,70 @@
 
                         if kwargs.get("permute", True):
                             speech = speech.permute(0, 2, 1)
+                        if speech_lengths > kwargs.get("max_source_length", 5500):
+                            # logging.info(
+                            #     f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}"
+                            # )
+                            badcase_flag = True
 
                         olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
                         olens = 1 + (olens - 3 + 2 * 1) // 2
-                        sub_token_len = (olens - 1) // 2 + 1
-                        sub_token = [0] * sub_token_len
-                        fbank_beg_i = [len(source_ids_i)]
-                        source_ids_i += sub_token
-                        fbank_mask_i += [1] * len(sub_token)
+                        fake_token_len_i = (olens - 1) // 2 + 1
+                        fake_token = [0] * fake_token_len_i
+                        fbank_beg_i = len(source_ids)
+                        source_ids += fake_token
+                        fbank_mask_i += [1] * len(fake_token)
 
-            source_mask = [-100] * len(source_ids_i)
+            fbank_beg += [fbank_beg_i + len(input_ids)]
+            fake_token_len += [fake_token_len_i]
+            source_mask = [-100] * len(source_ids)
             target_out = f"{target_out}<|im_end|>"
             target_ids = tokenizer.encode(target_out)
-            input_ids += source_ids_i + target_ids
+            input_source_ids = input_ids + source_ids
+            input_ids += source_ids + target_ids
             labels += source_mask + target_ids
             fbank_mask += fbank_mask_i
-            fbank_beg.append(fbank_beg_i)
+            if len(speech) > 0:
+                fbank.append(speech[0, :, :])
+                fbank_lens.append(speech_lengths)
 
         input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
         attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
         labels = torch.tensor(labels, dtype=torch.int64)  # [: self.max_token_length]
-        source_ids = torch.tensor(source_ids_i, dtype=torch.int64)
-        target_ids = torch.tensor(target_ids, dtype=torch.int64)
 
-        fbank = speech[0, :, :]
-        fbank_lens = speech_lengths
+        # fbank = speech[0, :, :]
+        # fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32)
         fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
         fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
+        fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
+        source_ids = torch.tensor(input_source_ids, dtype=torch.int64)
+        target_ids = torch.tensor(target_ids, dtype=torch.int64)
 
+        if len(fbank) > 0:
+            speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
+            speech_lengths = torch.nn.utils.rnn.pad_sequence(
+                fbank_lens, batch_first=True, padding_value=-1
+            )
+        else:
+            speech = []
+            speech_lengths = []
         output = {
-            "speech": fbank[None, :, :],
-            "speech_lengths": fbank_lens[:, None],
+            "speech": speech,
+            "speech_lengths": speech_lengths,
             "fbank_mask": fbank_mask[None, :],
             "fbank_beg": fbank_beg[None,],
-            "input_ids": input_ids[None, :],
-            "attention_mask": attention_mask[None, :],
-            "labels_ids": labels[None, :],
+            "fake_token_len": fake_token_len[None, :],
+            "input_ids": input_ids[None,],
+            "attention_mask": attention_mask[None,],
+            "labels_ids": labels,
             "source_ids": source_ids[None, :],
             "target_ids": target_ids[None, :],
         }
 
         return output
 
-    def inference(
+
+    def inference_prepare(
         self,
         data_in,
         data_lengths=None,
@@ -1229,34 +1270,79 @@
 
         # audio encoder
         speech = batch["speech"]
-        speech_lengths = batch["speech_lengths"][:, 0]
-        # fp16
-        if kwargs.get("fp16", False):
-            speech = speech.to(torch.float16)
-        elif kwargs.get("bf16", False):
-            speech = speech.to(torch.bfloat16)
-        # audio encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if len(speech) > 0:
+            speech_lengths = batch["speech_lengths"][:, 0]
+            # fp16
+            if kwargs.get("fp16", False):
+                speech = speech.to(torch.float16)
+            elif kwargs.get("bf16", False):
+                speech = speech.to(torch.bfloat16)
+            # audio encoder
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
-        # audio_adaptor
-        encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
+            # audio_adaptor
+            encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
 
         input_ids = batch["input_ids"]
         source_ids = batch["source_ids"]
+        fbank_beg = batch["fbank_beg"]
+        fake_token_len = batch["fake_token_len"]
+
         if not kwargs.get("tearchforing", False):
             input_ids = source_ids
+
         input_ids[input_ids < 0] = 0
         inputs_embeds = self.llm.model.get_input_embeddings()(input_ids)
 
         batch_size, token_num, dims = inputs_embeds.shape
-        fbank_beg = batch["fbank_beg"]
+
+        fake_token_len[fake_token_len < 0] = 0
+        fbank_beg[fbank_beg < 0] = 0
+
+        speech_idx = 0
         for batch_idx in range(batch_size):
 
-            min_len = encoder_out_lens[batch_idx].item()
-            fbank_beg_idx = fbank_beg[batch_idx]
-            inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
-                batch_idx, :min_len, :
-            ]
+            for turn_id in range(fbank_beg.shape[1]):
+                fbank_beg_idx = fbank_beg[batch_idx, turn_id].item()
+                if fbank_beg_idx > 0:
+                    speech_token_len = fake_token_len[batch_idx, turn_id]
+                    speech_token = encoder_out[speech_idx, :speech_token_len, :]
+
+                    try:
+                        inputs_embeds[
+                            batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
+                        ] = speech_token
+                    except Exception as e:
+                        #
+                        logging.error(f"{str(e)}, {traceback.format_exc()}")
+                        logging.info(
+                            f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, speech_token_len: {speech_token_len}, encoder_out: {encoder_out.shape}, encoder_out_lens: {encoder_out_lens}, fake_token_len: {fake_token_len}, speech_lengths: {speech_lengths}"
+                        )
+                        # import pdb;
+                        # pdb.set_trace()
+                        speech_token_len = encoder_out_lens[speech_idx].item()
+                        speech_token = encoder_out[speech_idx, :speech_token_len, :]
+                        inputs_embeds[
+                            batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
+                        ] = speech_token
+
+                    speech_idx += 1
+        return inputs_embeds, contents, batch, source_ids, meta_data
+    
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+
+        inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
+            data_in, data_lengths, key, tokenizer, frontend, **kwargs
+        )
 
         llm_dtype = kwargs.get("llm_dtype", "fp32")
         if llm_dtype == "fp32":
@@ -1266,7 +1352,7 @@
         with torch.cuda.amp.autocast(
             enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
         ):
-            label = contents["assistant"][0]
+            label = contents["assistant"][-1]
             self.llm = self.llm.to(dtype_map[llm_dtype])
             inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
 
@@ -1309,15 +1395,15 @@
             ibest_writer = self.writer[f"{0 + 1}best_recog"]
 
         results = []
-        response_clean = re.sub("[^\w\s\u3000\u4e00-\u9fff]+", "", response)
+        response_clean = re.sub(r"[^\w\s\u3000\u4e00-\u9fff]+", "", response)
         result_i = {"key": key[0], "text": response, "text_tn": response_clean, "label": label}
         if loss is not None:
             result_i["loss"] = loss
         results.append(result_i)
 
         if ibest_writer is not None:
-            ibest_writer["text"][key[0]] = response
-            ibest_writer["label"][key[0]] = label
+            ibest_writer["text"][key[0]] = response.replace("\n", " ")
+            ibest_writer["label"][key[0]] = label.replace("\n", " ")
             ibest_writer["text_tn"][key[0]] = response_clean
 
         return results, meta_data

--
Gitblit v1.9.1