游雁
2023-11-23 723488d97b256a2682af3bf8eb8a8da2c1a6990d
funasr/datasets/dataloader_fn.py
@@ -1,4 +1,4 @@
import time
import torch
from funasr.datasets.dataset_jsonl import AudioDataset
from funasr.datasets.data_sampler import BatchSampler
@@ -8,7 +8,7 @@
# Local debug configuration for exercising the jsonl dataset pipeline.
# NOTE(review): the absolute path below points at the author's machine;
# adjust it before running elsewhere.
# NOTE(review): the DataLoader uses dataset.collator, not this name;
# confirm whether collate_fn is still needed.
collate_fn = None
# Manifest read by AudioDataset (a previous all-task debug manifest was
# replaced by this AISHELL-2 dev/iOS one).
jsonl = "/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl"
frontend = WavFrontend()
token_type = 'char'
@@ -26,7 +26,7 @@
    non_linguistic_symbols=non_linguistic_symbols,
    g2p_type=g2p_type,
)
# tokens.txt of the Paraformer-large model in the local ModelScope cache.
# NOTE(review): hard-coded per-user cache path; adjust for other machines.
token_list = "/Users/zhifu/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt"
# Passed to TokenIDConverter as its unknown-token symbol.
unk_symbol = "<unk>"
token_id_converter = TokenIDConverter(
@@ -34,20 +34,33 @@
    unk_symbol=unk_symbol,
)
# Dataset over the jsonl manifest; the frontend, tokenizer and token-id
# converter configured above are handed to AudioDataset, and BatchSampler
# groups its samples into batches.
dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer, token_id_converter=token_id_converter)
batch_sampler = BatchSampler(dataset)
# The DataLoader and the iteration loop are created only under the
# `if __name__ == "__main__":` guard. DataLoader workers started with the
# "spawn" method re-import this module, so a loop at module level would be
# re-run inside every worker process.
def collator(samples: list = None):
    """Identity collate function: hand the batch back exactly as received.

    Args:
        samples: the samples to collate; ``None`` when called with no
            argument.

    Returns:
        The very same object that was passed in (no copy, no stacking or
        padding).
    """
    batch = samples
    return batch
if __name__ == "__main__":
    # Benchmark the loader: for three passes over the dataset, print the time
    # between consecutive batches delivered by 8 worker processes.
    # NOTE(review): pin_memory=True only pays off when batches are copied to a
    # CUDA device; confirm it is wanted on hosts without a GPU.
    dataloader_tr = torch.utils.data.DataLoader(dataset,
                                                collate_fn=dataset.collator,
                                                batch_sampler=batch_sampler,
                                                shuffle=False,
                                                num_workers=8,
                                                pin_memory=True)
    print(len(dataset))
    for i in range(3):
        print(i)
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so intervals are not skewed by wall-clock adjustments. The first
        # interval of each pass presumably includes worker start-up.
        beg = time.perf_counter()
        for j, _batch in enumerate(dataloader_tr):
            end = time.perf_counter()
            time_cost = end - beg
            beg = end
            print(j, time_cost)