| New file |
| | |
| | | import numpy as np |
| | | import torch |
| | | |
| | | from funasr.datasets.large_datasets.collate_fn import crop_to_max_size |
| | | |
| | | |
| | | def clipping(data): |
| | | assert isinstance(data, list) |
| | | assert "key" in data[0] |
| | | |
| | | keys = [x["key"] for x in data] |
| | | |
| | | batch = {} |
| | | data_names = data[0].keys() |
| | | for data_name in data_names: |
| | | if data_name == "key": |
| | | continue |
| | | else: |
| | | if data[0][data_name].dtype.kind == "i": |
| | | tensor_type = torch.int64 |
| | | else: |
| | | tensor_type = torch.float32 |
| | | |
| | | tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data] |
| | | tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32) |
| | | |
| | | length_clip = min(tensor_lengths) |
| | | tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1]) |
| | | for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)): |
| | | diff = length - length_clip |
| | | assert diff >= 0 |
| | | if diff == 0: |
| | | tensor_clip[i] = tensor |
| | | else: |
| | | tensor_clip[i] = crop_to_max_size(tensor, length_clip) |
| | | |
| | | batch[data_name] = tensor_clip |
| | | batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long) |
| | | |
| | | return keys, batch |