From 0033151b6286e9d17a9b91567ba649ff14a89464 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 17 六月 2024 00:55:42 +0800
Subject: [PATCH] decoding
---
funasr/datasets/openai_datasets/datasets.py | 36 ++++++++++++++++++++++++------------
1 files changed, 24 insertions(+), 12 deletions(-)
diff --git a/funasr/datasets/openai_datasets/datasets.py b/funasr/datasets/openai_datasets/datasets.py
index 3c2a957..ae9f289 100644
--- a/funasr/datasets/openai_datasets/datasets.py
+++ b/funasr/datasets/openai_datasets/datasets.py
@@ -300,8 +300,9 @@
return len(self.index_ds)
def __getitem__(self, index):
- # import pdb;
- # pdb.set_trace()
+ import pdb
+
+ pdb.set_trace()
output = None
@@ -318,7 +319,15 @@
user = item["user"]
assistant = item["assistant"]
- input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg = [], [], [], [], [], []
+ input_ids, labels, fbank, fbank_lens, fbank_mask, fbank_beg, fake_token_len = (
+ [],
+ [],
+ [],
+ [],
+ [],
+ [],
+ [],
+ )
for i, (system_prompt, user_prompt, target_out) in enumerate(
zip(system, user, assistant)
@@ -336,7 +345,8 @@
source_ids = []
fbank_i = []
fbank_mask_i = []
- fbank_beg_i = []
+ fake_token_len_i = 0
+ fbank_beg_i = -1
fbank_lens_i = []
for k, sub_str in enumerate(splits):
if not sub_str.startswith("<|startofspeech|>"):
@@ -369,14 +379,17 @@
olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
olens = 1 + (olens - 3 + 2 * 1) // 2
- sub_token_len = (olens - 1) // 2 + 1
- sub_token = [0] * sub_token_len
- fbank_beg_i = [len(source_ids)]
- source_ids += sub_token
- fbank_mask_i += [1] * len(sub_token)
+ fake_token_len_i = (olens - 1) // 2 + 1
+ fake_token = [0] * fake_token_len_i
+ fbank_beg_i = len(source_ids)
+ source_ids += fake_token
+ fbank_mask_i += [1] * len(fake_token)
if badcase_flag:
continue
+
+ fbank_beg += [fbank_beg_i + len(input_ids)]
+ fake_token_len += [fake_token_len_i]
source_mask = [-100] * len(source_ids)
target_out = f"{target_out}<|im_end|>"
target_ids = self.tokenizer.encode(target_out)
@@ -384,9 +397,6 @@
labels += source_mask + target_ids
fbank.append(speech[0, :, :])
fbank_mask += fbank_mask_i
- if len(fbank_beg_i) < 1:
- fbank_beg_i = [-1]
- fbank_beg += fbank_beg_i
if len(input_ids) > self.max_token_length:
logging.info(
@@ -403,12 +413,14 @@
fbank_lens = speech_lengths
fbank_mask = torch.tensor(fbank_mask, dtype=torch.float32)
fbank_beg = torch.tensor(fbank_beg, dtype=torch.int32)
+ fake_token_len = torch.tensor(fake_token_len, dtype=torch.int32)
output = {
"speech": fbank,
"speech_lengths": fbank_lens,
"fbank_mask": fbank_mask,
"fbank_beg": fbank_beg,
+ "fake_token_len": fake_token_len,
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels_ids": labels,
--
Gitblit v1.9.1