| | |
| | | import numpy as np |
| | | import torch |
| | | |
| | | from funasr.datasets.collate_fn import crop_to_max_size |
| | | from funasr.datasets.large_datasets.collate_fn import crop_to_max_size |
| | | |
| | | |
def clipping(data):
    """Collate a batch of samples by clipping every field to the batch's minimum length.

    For each data field (every key except "key"), all samples in the batch are
    cropped down to the length of the shortest sample via ``crop_to_max_size``,
    then stacked into a single tensor.

    NOTE(review): the file imports ``crop_to_max_size`` twice, from
    ``funasr.datasets.collate_fn`` and ``funasr.datasets.large_datasets.collate_fn``;
    the second import shadows the first — confirm which one is intended.

    Args:
        data: list of sample dicts. Each dict is assumed to contain a "key"
            entry identifying the sample, plus one or more array-like fields of
            shape (length, feat_dim) — TODO confirm against the dataset pipeline.

    Returns:
        keys: list of the samples' "key" values, in batch order.
        batch: dict mapping each field name to a (batch, min_length, feat_dim)
            tensor, plus "<name>_lengths" int64 tensors holding the clipped
            per-sample lengths (all equal to min_length after clipping).
    """
    assert isinstance(data, list)
    assert "key" in data[0]

    keys = [d["key"] for d in data]
    batch = {}

    for data_name in data[0]:
        if data_name == "key":
            continue

        # NOTE(review): the original chunk lost the tensor_list construction;
        # rebuilding it as a straight numpy -> torch conversion per sample.
        tensor_list = [torch.from_numpy(np.asarray(d[data_name])) for d in data]
        tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)

        # Clip every sample down to the shortest one in the batch.
        length_clip = min(tensor_lengths)
        tensor_clip = tensor_list[0].new_zeros(
            len(tensor_list), length_clip, tensor_list[0].shape[1]
        )
        for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
            diff = length - length_clip
            assert diff >= 0
            tensor_clip[i] = crop_to_max_size(tensor, length_clip)

        batch[data_name] = tensor_clip
        batch[data_name + "_lengths"] = torch.tensor(
            [tensor.shape[0] for tensor in tensor_clip], dtype=torch.long
        )

    return keys, batch