From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/datasets/openai_datasets/datasets.py |   57 +++++++++++++++++++++++++++++++++++++++------------------
 1 files changed, 39 insertions(+), 18 deletions(-)

diff --git a/funasr/datasets/openai_datasets/datasets.py b/funasr/datasets/openai_datasets/datasets.py
index 3c2a957..d670708 100644
--- a/funasr/datasets/openai_datasets/datasets.py
+++ b/funasr/datasets/openai_datasets/datasets.py
@@ -283,10 +283,11 @@
 
         self.pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
         # self.kwargs = kwargs
-        self.max_token_length = kwargs.get("max_token_length", 1024)
+        self.max_token_length = kwargs.get("max_token_length", 1500)
         self.batch_size_scale_ratio_max = kwargs.get("batch_size_scale_ratio_max", 1.5)
         self.batch_size_token_max = kwargs.get("batch_size_token_max", 2500)
         self.multiturn_num_max = kwargs.get("multiturn_num_max", 5)
+        self.max_source_length = kwargs.get("max_source_length", 3000)
 
     def get_source_len(self, index):
         item = self.index_ds[index]
@@ -300,7 +301,8 @@
         return len(self.index_ds)
 
     def __getitem__(self, index):
-        # import pdb;
+        # import pdb
+        #
         # pdb.set_trace()
 
         output = None
@@ -318,13 +320,27 @@
             user = item["user"]
             assistant = item["assistant"]
 
-            input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg = [], [], [], [], [], []
+            input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
+                [],
+                [],
+                [],
+                [],
+                [],
+                [],
+                [],
+            )
 
             for i, (system_prompt, user_prompt, target_out) in enumerate(
                 zip(system, user, assistant)
             ):
                 if i >= self.multiturn_num_max:
                     break
+                if len(input_ids) > self.max_token_length:
+                    logging.info(
+                        f"input_ids > max_token_length: {len(input_ids)}>{self.max_token_length}, {item}"
+                    )
+                    break
+
                 if i == 0:
                     source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
                 else:
@@ -336,7 +352,8 @@
                 source_ids = []
                 fbank_i = []
                 fbank_mask_i = []
-                fbank_beg_i = []
+                fake_token_len_i = 0
+                fbank_beg_i = -1
                 fbank_lens_i = []
                 for k, sub_str in enumerate(splits):
                     if not sub_str.startswith("<|startofspeech|>"):
@@ -362,6 +379,11 @@
                                 frontend=self.frontend,
                                 is_final=True,
                             )  # speech: [b, T, d]
+                            if speech_lengths > self.max_source_length:
+                                logging.info(
+                                    f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}"
+                                )
+                                badcase_flag = True
                             if self.permute:
                                 speech = speech.permute(0, 2, 1)
                             # if speech_lengths > self.batch_size:
@@ -369,14 +391,17 @@
 
                             olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
                             olens = 1 + (olens - 3 + 2 * 1) // 2
-                            sub_token_len = (olens - 1) // 2 + 1
-                            sub_token = [0] * sub_token_len
-                            fbank_beg_i = [len(source_ids)]
-                            source_ids += sub_token
-                            fbank_mask_i += [1] * len(sub_token)
+                            fake_token_len_i = (olens - 1) // 2 + 1
+                            fake_token = [0] * fake_token_len_i
+                            fbank_beg_i = len(source_ids)
+                            source_ids += fake_token
+                            fbank_mask_i += [1] * len(fake_token)
 
                 if badcase_flag:
                     continue
+
+                fbank_beg += [fbank_beg_i + len(input_ids)]
+                fake_token_len += [fake_token_len_i]
                 source_mask = [-100] * len(source_ids)
                 target_out = f"{target_out}<|im_end|>"
                 target_ids = self.tokenizer.encode(target_out)
@@ -384,31 +409,27 @@
                 labels += source_mask + target_ids
                 fbank.append(speech[0, :, :])
                 fbank_mask += fbank_mask_i
-                if len(fbank_beg_i) < 1:
-                    fbank_beg_i = [-1]
-                fbank_beg += fbank_beg_i
+                fbank_lens.append(speech_lengths)
 
-            if len(input_ids) > self.max_token_length:
-                logging.info(
-                    f"input_ids > max_token_length: {len(input_ids)}>{self.max_token_length}, {item}"
-                )
-                badcase_flag = True
             if badcase_flag:
                 continue
+
             input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
             attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
             labels = torch.tensor(labels, dtype=torch.int64)  # [: self.max_token_length]
 
             # fbank = speech[0, :, :]
-            fbank_lens = speech_lengths
+            # fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32)
             fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
             fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
+            fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
 
             output = {
                 "speech": fbank,
                 "speech_lengths": fbank_lens,
                 "fbank_mask": fbank_mask,
                 "fbank_beg": fbank_beg,
+                "fake_token_len": fake_token_len,
                 "input_ids": input_ids,
                 "attention_mask": attention_mask,
                 "labels_ids": labels,

--
Gitblit v1.9.1