From 79d12479db25a1a845f40da636bc6a9ecec7bf7e Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Thu, 25 Jan 2024 11:26:14 +0800
Subject: [PATCH] update run_server.sh
---
funasr/datasets/audio_datasets/samplers.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index bc71b28..0d93098 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -13,6 +13,7 @@
buffer_size: int = 30,
drop_last: bool = False,
shuffle: bool = True,
+ is_training: bool = True,
**kwargs):
self.drop_last = drop_last
@@ -20,14 +21,14 @@
self.dataset = dataset
self.total_samples = len(dataset)
self.batch_type = batch_type
- self.batch_size = batch_size
+ self.batch_size = int(batch_size)
self.buffer_size = buffer_size
self.max_token_length = kwargs.get("max_token_length", 5000)
self.shuffle_idx = np.arange(self.total_samples)
- self.shuffle = shuffle
+ self.shuffle = shuffle and is_training
def __len__(self):
- return self.total_samples
+ return (self.total_samples-1) // self.batch_size + 1
def set_epoch(self, epoch):
np.random.seed(epoch)
--
Gitblit v1.9.1