From f57b68121a526baea43b2e93f4540d8a2995f633 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 29 四月 2024 15:15:24 +0800
Subject: [PATCH] batch

---
 funasr/models/eend/eend_ola_dataloader.py |   24 ++++++++++++------------
 1 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/funasr/models/eend/eend_ola_dataloader.py b/funasr/models/eend/eend_ola_dataloader.py
index 2ee9272..43f00fb 100644
--- a/funasr/models/eend/eend_ola_dataloader.py
+++ b/funasr/models/eend/eend_ola_dataloader.py
@@ -12,17 +12,15 @@
     speech = [torch.from_numpy(np.copy(sph)).to(torch.float32) for sph in speech]
     speaker_labels = [torch.from_numpy(np.copy(spk)).to(torch.float32) for spk in speaker_labels]
     orders = [torch.from_numpy(np.copy(o)).to(torch.int64) for o in orders]
-    batch = dict(speech=speech,
-                 speaker_labels=speaker_labels,
-                 orders=orders)
+    batch = dict(speech=speech, speaker_labels=speaker_labels, orders=orders)
 
     return keys, batch
 
 
 class EENDOLADataset(Dataset):
     def __init__(
-            self,
-            data_file,
+        self,
+        data_file,
     ):
         self.data_file = data_file
         with open(data_file) as f:
@@ -44,14 +42,16 @@
         return key, speech, speaker_label, order
 
 
-class EENDOLADataLoader():
+class EENDOLADataLoader:
     def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
         dataset = EENDOLADataset(data_file)
-        self.data_loader = DataLoader(dataset,
-                                      batch_size=batch_size,
-                                      collate_fn=custom_collate,
-                                      shuffle=shuffle,
-                                      num_workers=num_workers)
+        self.data_loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            collate_fn=custom_collate,
+            shuffle=shuffle,
+            num_workers=num_workers,
+        )
 
     def build_iter(self, epoch):
-        return self.data_loader
\ No newline at end of file
+        return self.data_loader

--
Gitblit v1.9.1