From e65b1f701abca03bf3a1b5fbb200392aabd38c22 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 20 Jun 2024 17:09:33 +0800
Subject: [PATCH] Dev gzf deepspeed (#1833)

---
 funasr/datasets/openai_datasets/datasets.py |   20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/funasr/datasets/openai_datasets/datasets.py b/funasr/datasets/openai_datasets/datasets.py
index 04ddcfd..d670708 100644
--- a/funasr/datasets/openai_datasets/datasets.py
+++ b/funasr/datasets/openai_datasets/datasets.py
@@ -283,10 +283,11 @@
 
         self.pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
         # self.kwargs = kwargs
-        self.max_token_length = kwargs.get("max_token_length", 1024)
+        self.max_token_length = kwargs.get("max_token_length", 1500)
         self.batch_size_scale_ratio_max = kwargs.get("batch_size_scale_ratio_max", 1.5)
         self.batch_size_token_max = kwargs.get("batch_size_token_max", 2500)
         self.multiturn_num_max = kwargs.get("multiturn_num_max", 5)
+        self.max_source_length = kwargs.get("max_source_length", 3000)
 
     def get_source_len(self, index):
         item = self.index_ds[index]
@@ -334,6 +335,12 @@
             ):
                 if i >= self.multiturn_num_max:
                     break
+                if len(input_ids) > self.max_token_length:
+                    logging.info(
+                        f"input_ids > max_token_length: {len(input_ids)}>{self.max_token_length}, {item}"
+                    )
+                    break
+
                 if i == 0:
                     source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
                 else:
@@ -372,6 +379,11 @@
                                 frontend=self.frontend,
                                 is_final=True,
                             )  # speech: [b, T, d]
+                            if speech_lengths > self.max_source_length:
+                                logging.info(
+                                    f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}"
+                                )
+                                badcase_flag = True
                             if self.permute:
                                 speech = speech.permute(0, 2, 1)
                             # if speech_lengths > self.batch_size:
@@ -399,13 +411,9 @@
                 fbank_mask += fbank_mask_i
                 fbank_lens.append(speech_lengths)
 
-            if len(input_ids) > self.max_token_length:
-                logging.info(
-                    f"input_ids > max_token_length: {len(input_ids)}>{self.max_token_length}, {item}"
-                )
-                badcase_flag = True
             if badcase_flag:
                 continue
+
             input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
             attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
             labels = torch.tensor(labels, dtype=torch.int64)  # [: self.max_token_length]

--
Gitblit v1.9.1