From 8912e0696af069de47646fdb8a9d9c4e086e88b3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Sun, 14 Jan 2024 23:42:11 +0800
Subject: [PATCH] Resolve merge conflict
---
funasr/datasets/audio_datasets/samplers.py | 29 ++++++++++++++---------------
 1 file changed, 14 insertions(+), 15 deletions(-)
diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index d34fdea..9c87245 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -4,15 +4,16 @@
from funasr.register import tables
+
@tables.register("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
class BatchSampler(torch.utils.data.BatchSampler):
def __init__(self, dataset,
- batch_type: str="example",
- batch_size: int=100,
- buffer_size: int=30,
- drop_last: bool=False,
- shuffle: bool=True,
+ batch_type: str = "example",
+ batch_size: int = 100,
+ buffer_size: int = 30,
+ drop_last: bool = False,
+ shuffle: bool = True,
**kwargs):
self.drop_last = drop_last
@@ -25,24 +26,23 @@
self.max_token_length = kwargs.get("max_token_length", 5000)
self.shuffle_idx = np.arange(self.total_samples)
self.shuffle = shuffle
-
def __len__(self):
return self.total_samples
def set_epoch(self, epoch):
np.random.seed(epoch)
-
+
def __iter__(self):
if self.shuffle:
np.random.shuffle(self.shuffle_idx)
-
+
batch = []
max_token = 0
num_sample = 0
-
- iter_num = (self.total_samples-1) // self.buffer_size + 1
+
+ iter_num = (self.total_samples - 1) // self.buffer_size + 1
# print("iter_num: ", iter_num)
for iter in range(self.pre_idx + 1, iter_num):
datalen_with_index = []
@@ -50,12 +50,12 @@
idx = iter * self.buffer_size + i
if idx >= self.total_samples:
continue
-
+
idx_map = self.shuffle_idx[idx]
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
sample_len_cur = self.dataset.get_source_len(idx_map) + \
self.dataset.get_target_len(idx_map)
-
+
datalen_with_index.append([idx, sample_len_cur])
datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
@@ -63,7 +63,7 @@
idx, sample_len_cur_raw = item
if sample_len_cur_raw > self.max_token_length:
continue
-
+
max_token_cur = max(max_token, sample_len_cur_raw)
max_token_padding = 1 + num_sample
if self.batch_type == 'length':
@@ -77,5 +77,4 @@
batch = [idx]
max_token = sample_len_cur_raw
num_sample = 1
-
-
\ No newline at end of file
+
--
Gitblit v1.9.1