From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' on line 514 to avoid a naming conflict with line 365
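
Beyond the rename, the diff below reworks multi-turn batching in
funasr/datasets/openai_datasets/datasets.py: it raises the default
max_token_length from 1024 to 1500 and checks it before building each
turn rather than after, adds a max_source_length guard (default 3000)
that flags over-long speech as a bad case, accumulates per-turn fbank
tensors, lengths, begin offsets (fbank_beg) and placeholder-token
counts (fake_token_len) instead of keeping only the last turn, and
makes collate() flatten list-valued fields so per-turn tensors from
one sample batch alongside those of other samples.

A minimal sketch of the placeholder ("fake") token length computation
applied to each speech segment; the standalone function and the name
fake_token_len_for are illustrative only, and the reading of the
arithmetic as conv subsampling (kernel 3, stride 2, padding 1) is an
assumption about the encoder — the formulas themselves mirror the
olens lines in __getitem__:

    def fake_token_len_for(speech_len: int) -> int:
        # (L + 2*pad - kernel) // stride + 1, applied twice,
        # matching the two olens lines in the patch:
        olens = 1 + (speech_len - 3 + 2 * 1) // 2
        olens = 1 + (olens - 3 + 2 * 1) // 2
        # One placeholder token is reserved per two encoder frames.
        return (olens - 1) // 2 + 1

    # e.g. a 100-frame fbank reserves 13 placeholder token slots
    assert fake_token_len_for(100) == 13
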
---
funasr/datasets/openai_datasets/datasets.py | 65 +++++++++++++++++++++++---------
1 file changed, 46 insertions(+), 19 deletions(-)
diff --git a/funasr/datasets/openai_datasets/datasets.py b/funasr/datasets/openai_datasets/datasets.py
index 7300b9d..d670708 100644
--- a/funasr/datasets/openai_datasets/datasets.py
+++ b/funasr/datasets/openai_datasets/datasets.py
@@ -283,10 +283,11 @@
self.pattern = re.compile(r"(<\|startofspeech\|>.*?<\|endofspeech\|>)")
# self.kwargs = kwargs
- self.max_token_length = kwargs.get("max_token_length", 1024)
+ self.max_token_length = kwargs.get("max_token_length", 1500)
self.batch_size_scale_ratio_max = kwargs.get("batch_size_scale_ratio_max", 1.5)
self.batch_size_token_max = kwargs.get("batch_size_token_max", 2500)
self.multiturn_num_max = kwargs.get("multiturn_num_max", 5)
+ self.max_source_length = kwargs.get("max_source_length", 3000)
def get_source_len(self, index):
item = self.index_ds[index]
@@ -300,7 +301,8 @@
return len(self.index_ds)
def __getitem__(self, index):
- # import pdb;
+ # import pdb
+ #
# pdb.set_trace()
output = None
@@ -318,13 +320,27 @@
user = item["user"]
assistant = item["assistant"]
- input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg = [], [], [], [], [], []
+ input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
+ [],
+ [],
+ [],
+ [],
+ [],
+ [],
+ [],
+ )
for i, (system_prompt, user_prompt, target_out) in enumerate(
zip(system, user, assistant)
):
if i >= self.multiturn_num_max:
break
+ if len(input_ids) > self.max_token_length:
+ logging.info(
+ f"input_ids > max_token_length: {len(input_ids)}>{self.max_token_length}, {item}"
+ )
+ break
+
if i == 0:
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
else:
@@ -334,8 +350,10 @@
splits = self.pattern.split(source_input)
source_ids = []
+ fbank_i = []
fbank_mask_i = []
- fbank_beg_i = []
+ fake_token_len_i = 0
+ fbank_beg_i = -1
fbank_lens_i = []
for k, sub_str in enumerate(splits):
if not sub_str.startswith("<|startofspeech|>"):
@@ -361,6 +379,11 @@
frontend=self.frontend,
is_final=True,
) # speech: [b, T, d]
+ if speech_lengths > self.max_source_length:
+ logging.info(
+ f"speech_lengths > max_source_length: {speech_lengths}>{self.max_source_length}, {item}"
+ )
+ badcase_flag = True
if self.permute:
speech = speech.permute(0, 2, 1)
# if speech_lengths > self.batch_size:
@@ -368,44 +391,45 @@
olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
olens = 1 + (olens - 3 + 2 * 1) // 2
- sub_token_len = (olens - 1) // 2 + 1
- sub_token = [0] * sub_token_len
- fbank_beg_i = [len(source_ids)]
- source_ids += sub_token
- fbank_mask_i += [1] * len(sub_token)
+ fake_token_len_i = (olens - 1) // 2 + 1
+ fake_token = [0] * fake_token_len_i
+ fbank_beg_i = len(source_ids)
+ source_ids += fake_token
+ fbank_mask_i += [1] * len(fake_token)
if badcase_flag:
continue
+
+ fbank_beg += [fbank_beg_i + len(input_ids)]
+ fake_token_len += [fake_token_len_i]
source_mask = [-100] * len(source_ids)
target_out = f"{target_out}<|im_end|>"
target_ids = self.tokenizer.encode(target_out)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
- fbank.append(speech)
+ fbank.append(speech[0, :, :])
fbank_mask += fbank_mask_i
- fbank_beg.append(fbank_beg_i)
+ fbank_lens.append(speech_lengths)
- if len(input_ids) > self.max_token_length:
- logging.info(
- f"input_ids > max_token_length: {len(input_ids)}>{self.max_token_length}, {item}"
- )
- badcase_flag = True
if badcase_flag:
continue
+
input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]
- fbank = speech[0, :, :]
- fbank_lens = speech_lengths
+ # fbank = speech[0, :, :]
+ # fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32)
fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
+ fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
output = {
"speech": fbank,
"speech_lengths": fbank_lens,
"fbank_mask": fbank_mask,
"fbank_beg": fbank_beg,
+ "fake_token_len": fake_token_len,
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels_ids": labels,
@@ -426,7 +450,10 @@
for key in sample.keys():
if key not in outputs:
outputs[key] = []
- outputs[key].append(sample[key])
+ if isinstance(sample[key], (list, tuple)):
+ outputs[key].extend(sample[key])
+ else:
+ outputs[key].append(sample[key])
for key, data_list in outputs.items():
if isinstance(data_list[0], torch.Tensor):
--
Gitblit v1.9.1