| | |
| | | import torch |
| | | import json |
| | | import torch.distributed as dist |
| | | import numpy as np |
| | | import kaldiio |
| | | import librosa |
| | | |
| | | class AudioDatasetJsonl(torch.utils.data.Dataset): |
| | | |
| | | |
def load_audio(audio_path: str, fs: int = 16000):
    """Load audio data from ``audio_path``.

    Args:
        audio_path: Location of the audio. Paths starting with ``"oss:"`` or
            ``"odps:"`` are recognised but nothing is loaded for them. Paths
            containing ``".ark:"`` are read with ``kaldiio.load_mat``. Any
            other path is read with ``librosa.load``.
        fs: Sample rate passed to ``librosa.load`` as ``sr``.

    Returns:
        The loaded data, or ``None`` for ``"oss:"`` / ``"odps:"`` paths.
    """
    # Remote-storage schemes are placeholders: no loader exists for them yet.
    if audio_path.startswith(("oss:", "odps:")):
        return None

    if ".ark:" in audio_path:
        return kaldiio.load_mat(audio_path)

    # librosa also returns the sample rate; only the waveform is handed back.
    waveform, _ = librosa.load(audio_path, sr=fs)
    return waveform
| | | |
def extract_features(data, date_type: str = "sound", frontend=None):
    """Turn loaded data into a feature tensor and its length.

    Args:
        data: Input data. For ``date_type == "sound"`` it is passed to
            ``frontend`` together with ``len(data)``; otherwise it is
            converted with ``torch.from_numpy``.
        date_type: ``"sound"`` selects the frontend path; any other value
            selects the direct conversion path.
        frontend: Callable invoked as ``frontend(data, len(data))`` that
            returns ``(features, lengths)``. Only used for ``"sound"``.

    Returns:
        Tuple ``(feat, feats_lens)``. On the frontend path ``feat`` is the
        first entry along the leading axis of the frontend output and
        ``feats_lens`` is returned as the frontend produced it. On the other
        path ``feat`` is a float32 tensor and ``feats_lens`` is a one-element
        int32 tensor holding ``data.shape[0]``.
    """
    if date_type != "sound":
        # Data that is not raw sound is only converted to tensors.
        feat = torch.from_numpy(data).to(torch.float32)
        feats_lens = torch.tensor([data.shape[0]]).to(torch.int32)
        return feat, feats_lens

    batched_feat, feats_lens = frontend(data, len(data))
    # Drop the leading axis of the frontend output.
    return batched_feat[0, :, :], feats_lens
| | | |
| | | def __init__(self, path, data_parallel_rank=0, data_parallel_size=1): |
| | | |
| | | |
| | | class IndexedDatasetJsonl(torch.utils.data.Dataset): |
| | | |
| | | def __init__(self, path): |
| | | super().__init__() |
| | | data_parallel_size = dist.get_world_size() |
| | | # data_parallel_size = dist.get_world_size() |
| | | data_parallel_size = 1 |
| | | contents = [] |
| | | with open(path, encoding='utf-8') as fin: |
| | | for line in fin: |
| | |
| | | self.contents = [] |
| | | total_num = len(contents) |
| | | num_per_rank = total_num // data_parallel_size |
| | | rank = dist.get_rank() |
| | | # rank = dist.get_rank() |
| | | rank = 0 |
| | | # import ipdb; ipdb.set_trace() |
| | | self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank] |
| | | |
| | |
| | | |
def __getitem__(self, index):
    """Return the entry of ``self.contents`` at position ``index``."""
    # ``self.contents`` is the slice of loaded records assigned in ``__init__``.
    return self.contents[index]
| | | |
| | | |
class AudioDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding speech features and tokenized text.

    Records come from an ``IndexedDatasetJsonl``; each record is read through
    its ``"source"`` (audio location) and ``"target"`` (text) keys.
    """

    def __init__(self, path, frontend=None, tokenizer=None):
        """Build the dataset.

        Args:
            path: Passed unchanged to ``IndexedDatasetJsonl``.
            frontend: Object exposing ``forward`` and ``fs``. May be ``None``,
                in which case the sample rate falls back to 16000 and no
                feature extractor is stored.
            tokenizer: Object exposing ``encode(text)``; used by
                ``__getitem__``.
        """
        super().__init__()
        self.indexed_dataset = IndexedDatasetJsonl(path)
        # Guard ``forward`` the same way ``fs`` is guarded, so constructing
        # the dataset without a frontend does not raise AttributeError.
        self.frontend = None if frontend is None else frontend.forward
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer
        # Padding values used by ``collator`` for integer / floating tensors.
        self.int_pad_value = -1
        self.float_pad_value = 0.0

    def __len__(self):
        """Return the number of records in the underlying indexed dataset."""
        return len(self.indexed_dataset)

    def __getitem__(self, index):
        """Load, featurize and tokenize the record at ``index``.

        Returns:
            Dict with keys ``"speech"``, ``"speech_lengths"``, ``"text"``
            (int64 tensor of token ids) and ``"text_lengths"`` (one-element
            int32 tensor).
        """
        item = self.indexed_dataset[index]

        data_src = load_audio(item["source"], fs=self.fs)
        speech, speech_lengths = extract_features(
            data_src, self.data_type, self.frontend
        )

        token_ids = self.tokenizer.encode(item["target"])
        text = torch.tensor(token_ids, dtype=torch.int64)
        text_lengths = torch.tensor([len(token_ids)], dtype=torch.int32)

        return {
            "speech": speech,
            "speech_lengths": speech_lengths,
            "text": text,
            "text_lengths": text_lengths,
        }

    def collator(self, samples: list = None):
        """Pad a list of ``__getitem__`` results into one batch.

        Args:
            samples: List of dicts whose values are tensors. ``None`` or an
                empty list yields an empty batch.

        Returns:
            Dict mapping every key found in ``samples`` to a batch-first
            tensor built with ``pad_sequence``. Floating-point tensors are
            padded with ``self.float_pad_value``, all others with
            ``self.int_pad_value``.
        """
        if not samples:
            return {}

        # Group the per-sample tensors by key, preserving sample order.
        grouped = {}
        for sample in samples:
            for key, value in sample.items():
                grouped.setdefault(key, []).append(value)

        outputs = {}
        for key, data_list in grouped.items():
            # torch dtypes have no numpy-style ``.kind``; ask torch directly
            # whether the tensor is floating point.
            if torch.is_floating_point(data_list[0]):
                pad_value = self.float_pad_value
            else:
                pad_value = self.int_pad_value
            outputs[key] = torch.nn.utils.rnn.pad_sequence(
                data_list, batch_first=True, padding_value=pad_value
            )
        # Return the padded batch, not the untouched input list.
        return outputs