| | |
| | | batch = {} |
| | | data_names = data[0].keys() |
| | | for data_name in data_names: |
| | | if data_name == "key" or data_name == "sampling_rate" or data_name == 'hotword_indxs': |
| | | batch[data_name] = data[0][data_name] |
| | | if data_name == "key" or data_name == "sampling_rate": |
| | | continue |
| | | else: |
| | | if data[0][data_name].dtype.kind == "i": |
| | | pad_value = int_pad_value |
| | | tensor_type = torch.int64 |
| | | else: |
| | | pad_value = float_pad_value |
| | | tensor_type = torch.float32 |
| | | if data_name != 'hotword_indxs': |
| | | if data[0][data_name].dtype.kind == "i": |
| | | pad_value = int_pad_value |
| | | tensor_type = torch.int64 |
| | | else: |
| | | pad_value = float_pad_value |
| | | tensor_type = torch.float32 |
| | | |
| | | tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data] |
| | | tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32) |