| | |
| | | import numpy as np |
| | | import kaldiio |
| | | import librosa |
| | | |
| | | |
| | | import torchaudio |
| | | import time |
| | | |
def load_audio(audio_path: str, fs: int = 16000):
    """Load audio from a Kaldi ark specifier or from an audio file.

    Args:
        audio_path: Either a Kaldi ark specifier (any string containing
            ``".ark:"``) or the path of an audio file to decode.
        fs: Accepted for interface compatibility. The returned data does
            not depend on it: the file branch decodes with
            ``torchaudio.load`` and performs no resampling step here.

    Returns:
        For ark specifiers, whatever ``kaldiio.load_mat`` returns.
        Otherwise the first channel (row 0) of the waveform returned by
        ``torchaudio.load``.
    """
    if ".ark:" in audio_path:
        return kaldiio.load_mat(audio_path)

    # The file used to be decoded with librosa first and the result thrown
    # away; only the torchaudio decode is ever returned, so decode once.
    waveform, _sample_rate = torchaudio.load(audio_path)
    # Keep only the first channel.
    return waveform[0, :]
| | | |
| | | def extract_features(data, date_type: str="sound", frontend=None): |
| | |
| | | |
def __getitem__(self, index):
    """Return the entry stored at ``index`` in ``self.contents``."""
    entry = self.contents[index]
    return entry
| | | |
def get_source_len(self, data_dict):
    """Return the ``"source_len"`` entry of ``data_dict``.

    Raises:
        KeyError: If ``data_dict`` has no ``"source_len"`` key.
    """
    source_len = data_dict["source_len"]
    return source_len
| | | |
def get_target_len(self, data_dict):
    """Return the ``"target_len"`` entry of ``data_dict``, or 0 if absent."""
    # Single lookup with a default instead of a membership test plus index.
    return data_dict.get("target_len", 0)
| | | |
| | | |
| | | class AudioDataset(torch.utils.data.Dataset): |
def __init__(self, path, frontend=None, tokenizer=None, int_pad_value: int = -1, float_pad_value: float = 0.0, token_id_converter=None, **kwargs):
    """Set up the dataset from an index file and optional processing objects.

    Args:
        path: Passed to ``IndexedDatasetJsonl`` to build the sample index.
        frontend: Optional object exposing ``forward`` and ``fs``. When
            omitted, no frontend callable is stored and the sample rate
            defaults to 16000.
        tokenizer: Stored as ``self.tokenizer``.
        int_pad_value: Stored as ``self.int_pad_value``.
        float_pad_value: Stored as ``self.float_pad_value``.
        token_id_converter: Stored as ``self.token_id_converter``. Declared
            after the pad values so positional calls to the newer
            signature keep their meaning.
        **kwargs: Accepted and ignored.
    """
    super().__init__()
    self.indexed_dataset = IndexedDatasetJsonl(path)

    # `frontend` defaults to None, so dereferencing `.forward` without a
    # guard would raise AttributeError on the default call.
    if frontend is None:
        self.frontend = None
        self.fs = 16000
    else:
        self.frontend = frontend.forward
        self.fs = frontend.fs

    self.data_type = "sound"
    self.tokenizer = tokenizer
    self.token_id_converter = token_id_converter

    # Pad values come from the arguments; the earlier hard-coded
    # assignments were overwritten immediately and have been dropped.
    self.int_pad_value = int_pad_value
    self.float_pad_value = float_pad_value
| | | |
| | | |
| | | |
| | |
| | | data_src = load_audio(source, fs=self.fs) |
| | | speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend) |
| | | target = item["target"] |
| | | text = self.tokenizer.text2tokens(target) |
| | | ids = self.token_id_converter.tokens2ids(text) |
| | | ids = self.tokenizer.encode(target) |
| | | ids_lengths = len(ids) |
| | | text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32) |
| | | |