From 8f596af4be1c2e5c4e4b4a7008ba96f412d40fca Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 29 Apr 2024 14:32:43 +0800
Subject: [PATCH] batch: show per-worker tqdm progress in scp2jsonl; accept **kwargs in tokenizer encode

---
 funasr/tokenizer/abs_tokenizer.py           |    2 +-
 funasr/datasets/audio_datasets/scp2jsonl.py |   13 ++++++++++---
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/funasr/datasets/audio_datasets/scp2jsonl.py b/funasr/datasets/audio_datasets/scp2jsonl.py
index f6ceb69..f167173 100644
--- a/funasr/datasets/audio_datasets/scp2jsonl.py
+++ b/funasr/datasets/audio_datasets/scp2jsonl.py
@@ -7,6 +7,7 @@
 import concurrent.futures
 import librosa
 import torch.distributed as dist
+from tqdm import tqdm
 
 
 def gen_jsonl_from_wav_text_list(
@@ -28,6 +29,7 @@
             with open(data_file, "r") as f:
 
                 data_file_lists = f.readlines()
+                print("")
                 lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
                 task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
                 # import pdb;pdb.set_trace()
@@ -41,6 +43,7 @@
                                     i * lines_for_each_th : (i + 1) * lines_for_each_th
                                 ],
                                 data_type,
+                                i,
                             )
                             for i in range(task_num)
                         ]
@@ -69,11 +72,15 @@
         dist.barrier()
 
 
-def parse_context_length(data_list: list, data_type: str):
-
+def parse_context_length(data_list: list, data_type: str, id=0):
+    pbar = tqdm(total=len(data_list), dynamic_ncols=True)
     res = {}
     for i, line in enumerate(data_list):
-        key, line = line.strip().split(maxsplit=1)
+        pbar.update(1)
+        pbar.set_description(f"cpu: {id}")
+        lines = line.strip().split(maxsplit=1)
+        key = lines[0]
+        line = lines[1] if len(lines) > 1 else ""
         line = line.strip()
         if os.path.exists(line):
             waveform, _ = librosa.load(line, sr=16000)
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index a629e94..e125d29 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -62,7 +62,7 @@
                 raise RuntimeError(f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list")
             self.unk_id = self.token2id[self.unk_symbol]
 
-    def encode(self, text):
+    def encode(self, text, **kwargs):
         tokens = self.text2tokens(text)
         text_ints = self.tokens2ids(tokens)
 

--
Gitblit v1.9.1