| | |
| | | output = (uttids, output) |
| | | assert check_return_type(output) |
| | | return output |
| | | |
def crop_to_max_size(feature, target_size):
    """Randomly crop ``feature`` along its first axis to ``target_size`` items.

    The start of the contiguous window is drawn with ``np.random.randint``
    (NumPy's global RNG). If ``feature`` is not longer than ``target_size``,
    it is returned unchanged; no padding is ever applied.
    """
    excess = len(feature) - target_size
    if excess > 0:
        # Any offset in [0, excess] keeps the window inside ``feature``.
        offset = np.random.randint(0, excess + 1)
        return feature[offset:offset + target_size]
    return feature
| | | |
| | | |
def clipping_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    max_sample_size=None,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Collate a mini-batch by randomly cropping every array to a common length.

    Mainly for pre-training. The items are not padded to the longest one.
    Instead, each array is cropped along its first axis (via
    ``crop_to_max_size``) to the length of the shortest item for that key,
    optionally capped further by ``max_sample_size``.

    Args:
        data: Sequence of ``(uttid, {key: array})`` pairs. Every dict must have
            the same keys, and no key may end with ``"_lengths"``.
        max_sample_size: Optional upper bound on the cropped length. ``None``
            means "use the length of the shortest item".
        not_sequence: Keys for which no ``<key>_lengths`` entry is produced.

    Returns:
        ``(uttids, output)``. ``output[key]`` is a tensor of shape
        ``(batch, target_size, *trailing_dims)``. For keys not in
        ``not_sequence``, ``output[key + "_lengths"]`` is a ``torch.long``
        tensor of shape ``(batch,)`` filled with ``target_size``.
    """
    assert check_argument_types()
    uttids = [u for u, _ in data]
    data = [d for _, d in data]

    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in data[0]
    ), f"*_lengths is reserved: {list(data[0])}"

    output = {}
    for key in data[0]:
        tensor_list = [torch.from_numpy(d[key]) for d in data]
        target_size = min(len(t) for t in tensor_list)
        if max_sample_size is not None:
            target_size = min(target_size, max_sample_size)

        # Keep every trailing dimension rather than assuming exactly one
        # feature axis, so 1-D inputs (e.g. raw waveforms) are supported too.
        first = tensor_list[0]
        batch = first.new_zeros(
            (len(tensor_list), target_size) + tuple(first.shape[1:])
        )
        for i, source in enumerate(tensor_list):
            # crop_to_max_size returns ``source`` untouched (and draws no
            # random number) when it is already ``target_size`` long.
            batch[i] = crop_to_max_size(source, target_size)
        output[key] = batch

        if key not in not_sequence:
            # All rows were cropped to the same length.
            output[key + "_lengths"] = torch.full(
                (len(tensor_list),), target_size, dtype=torch.long
            )

    output = (uttids, output)
    assert check_return_type(output)
    return output