游雁
2024-04-29 8f596af4be1c2e5c4e4b4a7008ba96f412d40fca
batch
2个文件已修改
15 ■■■■ 已修改文件
funasr/datasets/audio_datasets/scp2jsonl.py 13 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/abs_tokenizer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/scp2jsonl.py
@@ -7,6 +7,7 @@
import concurrent.futures
import librosa
import torch.distributed as dist
from tqdm import tqdm
def gen_jsonl_from_wav_text_list(
@@ -28,6 +29,7 @@
            with open(data_file, "r") as f:
                data_file_lists = f.readlines()
                print("")
                lines_for_each_th = (len(data_file_lists) - 1) // cpu_cores + 1
                task_num = cpu_cores if len(data_file_lists) > cpu_cores else 1
                # import pdb;pdb.set_trace()
@@ -41,6 +43,7 @@
                                    i * lines_for_each_th : (i + 1) * lines_for_each_th
                                ],
                                data_type,
                                i,
                            )
                            for i in range(task_num)
                        ]
@@ -69,11 +72,15 @@
        dist.barrier()
def parse_context_length(data_list: list, data_type: str):
def parse_context_length(data_list: list, data_type: str, id=0):
    pbar = tqdm(total=len(data_list), dynamic_ncols=True)
    res = {}
    for i, line in enumerate(data_list):
        key, line = line.strip().split(maxsplit=1)
        pbar.update(1)
        pbar.set_description(f"cpu: {id}")
        lines = line.strip().split(maxsplit=1)
        key = lines[0]
        line = lines[1] if len(lines) > 1 else ""
        line = line.strip()
        if os.path.exists(line):
            waveform, _ = librosa.load(line, sr=16000)
funasr/tokenizer/abs_tokenizer.py
@@ -62,7 +62,7 @@
                raise RuntimeError(f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list")
            self.unk_id = self.token2id[self.unk_symbol]
    def encode(self, text):
    def encode(self, text, **kwargs):
        tokens = self.text2tokens(text)
        text_ints = self.tokens2ids(tokens)