| funasr/datasets/collate_fn.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/datasets/large_datasets/dataset.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/datasets/large_datasets/utils/filter.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/datasets/collate_fn.py
# funasr/datasets/collate_fn.py — hunk @@ -80,4 +80,56 @@, reconstructed from a
# collapsed diff view. (The leading "output = (uttids, output) ..." fragment in
# the raw hunk is the tail of the preceding collate function, whose body lies
# outside this hunk.)


def crop_to_max_size(feature, target_size):
    """Randomly crop ``feature`` along axis 0 to at most ``target_size`` frames.

    If the feature is already no longer than ``target_size`` it is returned
    unchanged (same object); otherwise a contiguous window of exactly
    ``target_size`` frames is taken, with the start offset drawn uniformly
    at random.
    """
    excess = len(feature) - target_size
    if excess <= 0:
        return feature
    # Uniform start offset over every valid crop position.
    begin = np.random.randint(0, excess + 1)
    return feature[begin:begin + target_size]


def clipping_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    max_sample_size=None,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Collate a mini-batch by cropping every sample to a common length.

    Mainly for pre-training: instead of padding to the longest sample, each
    per-key array is randomly cropped to the length of the shortest sample in
    the batch (optionally capped by ``max_sample_size``).

    Args:
        data: collection of ``(uttid, {key: np.ndarray})`` pairs. All samples
            must share the same key set, and no key may end in ``_lengths``
            (that suffix is reserved for the emitted length tensors).
        max_sample_size: optional upper bound on the common crop length.
        not_sequence: keys for which no ``<key>_lengths`` tensor is emitted.

    Returns:
        ``(uttids, batch)`` where ``batch[key]`` is the stacked cropped tensor
        and ``batch[key + "_lengths"]`` holds the (uniform) cropped lengths.
    """
    assert check_argument_types()

    uttids = [uttid for uttid, _ in data]
    samples = [sample for _, sample in data]

    # Every sample must expose exactly the same keys.
    assert all(set(samples[0]) == set(sample) for sample in samples), "dict-keys mismatching"
    assert all(
        not key.endswith("_lengths") for key in samples[0]
    ), f"*_lengths is reserved: {list(samples[0])}"

    output = {}
    for key in samples[0]:
        tensors = [torch.from_numpy(sample[key]) for sample in samples]
        lengths = [len(t) for t in tensors]

        # Common target length: shortest sample in the batch, optionally capped.
        target_size = min(lengths)
        if max_sample_size is not None:
            target_size = min(target_size, max_sample_size)

        # NOTE(review): assumes each per-key array is 2-D (frames, dim) —
        # confirm against the upstream feature pipeline.
        batch = tensors[0].new_zeros(len(tensors), target_size, tensors[0].shape[1])
        for row, source in enumerate(tensors):
            # crop_to_max_size returns its input untouched when it already
            # fits, so no separate exact-size branch is needed.
            batch[row] = crop_to_max_size(source, target_size)
        output[key] = batch

        if key not in not_sequence:
            # After cropping, every row has exactly target_size frames.
            lens = torch.tensor([row.shape[0] for row in batch], dtype=torch.long)
            output[key + "_lengths"] = lens

    output = (uttids, output)
    assert check_return_type(output)
    return output
@@ -102,6 +102,8 @@ elif data_type == "text" or data_type == "sound": text_reader = open(data_file, "r") reader_list.append(text_reader) elif data_type == "none": continue else: raise TypeError("Data type {} is not supported".format(data_type)) funasr/datasets/large_datasets/utils/filter.py
@@ -6,13 +6,21 @@ speech_length_max=15000, token_length_min=0, token_length_max=200): assert "speech" in data assert "text" in data assert "speech" in data or "text" in data if "sampling_rate" in data: speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000. if "speech" in data and "text" in data: if "sampling_rate" in data: speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000. else: speech_length = data["speech"].shape[0] num_tokens = len(data['text']) return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max elif "speech" in data: if "sampling_rate" in data: speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000. else: speech_length = data["speech"].shape[0] return speech_length_min < speech_length < speech_length_max else: speech_length = data["speech"].shape[0] num_tokens = len(data['text']) return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max num_tokens = len(data['text']) return token_length_min < num_tokens < token_length_max