From 723488d97b256a2682af3bf8eb8a8da2c1a6990d Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 23 Nov 2023 16:16:20 +0800
Subject: [PATCH] funasr v2

---
 funasr/datasets/dataloader_fn.py |   49 +++++++++++++++++++++++++++++++------------------
 1 file changed, 31 insertions(+), 18 deletions(-)

diff --git a/funasr/datasets/dataloader_fn.py b/funasr/datasets/dataloader_fn.py
index 8e8e423..3393a33 100644
--- a/funasr/datasets/dataloader_fn.py
+++ b/funasr/datasets/dataloader_fn.py
@@ -1,4 +1,4 @@
-
+import time
 import torch
 from funasr.datasets.dataset_jsonl import AudioDataset
 from funasr.datasets.data_sampler import BatchSampler
@@ -8,7 +8,7 @@
 collate_fn = None
 # collate_fn = collate_fn,
 
-jsonl = "/Users/zhifu/funasr_github/test_local/all_task_debug_len.jsonl"
+jsonl = "/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl"
 
 frontend = WavFrontend()
 token_type = 'char'
@@ -26,7 +26,7 @@
     non_linguistic_symbols=non_linguistic_symbols,
     g2p_type=g2p_type,
 )
-token_list = ""
+token_list = "/Users/zhifu/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt"
 unk_symbol = "<unk>"
 
 token_id_converter = TokenIDConverter(
@@ -34,20 +34,33 @@
     unk_symbol=unk_symbol,
 )
 
-dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer)
+dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer, token_id_converter=token_id_converter)
 batch_sampler = BatchSampler(dataset)
-dataloader_tr = torch.utils.data.DataLoader(dataset,
-                           collate_fn=dataset.collator,
-                           batch_sampler=batch_sampler,
-                           shuffle=False,
-                           num_workers=0,
-                           pin_memory=True)
 
-print(len(dataset))
-for i in range(3):
-    print(i)
-    for data in dataloader_tr:
-        print(len(data), data)
-# data_iter = iter(dataloader_tr)
-# data = next(data_iter)
-pass
+
+def collator(samples: list = None):
+    return samples
+
+if __name__ == "__main__":
+    
+    dataloader_tr = torch.utils.data.DataLoader(dataset,
+                                                collate_fn=dataset.collator,
+                                                batch_sampler=batch_sampler,
+                                                shuffle=False,
+                                                num_workers=8,
+                                                pin_memory=True)
+    
+    print(len(dataset))
+    for i in range(3):
+        print(i)
+        beg = time.time()
+        for j, data in enumerate(dataloader_tr):
+            end = time.time()
+            time_cost = end - beg
+            beg = end
+            print(j, time_cost)
+    # data_iter = iter(dataloader_tr)
+    # data = next(data_iter)
+    pass
+
+    
\ No newline at end of file

--
Gitblit v1.9.1