shixian.shi
2023-05-04 1988fe85f6d4e2d2f809e705e13d69d0b57bd0fc
funasr/datasets/large_datasets/utils/padding.py
@@ -13,16 +13,16 @@
    batch = {}
    data_names = data[0].keys()
    for data_name in data_names:
        if data_name == "key" or data_name == "sampling_rate" or data_name == 'hotword_indxs':
            batch[data_name] = data[0][data_name]
        if data_name == "key" or data_name == "sampling_rate":
            continue
        else:
            if data[0][data_name].dtype.kind == "i":
                pad_value = int_pad_value
                tensor_type = torch.int64
            else:
                pad_value = float_pad_value
                tensor_type = torch.float32
            if data_name != 'hotword_indxs':
                if data[0][data_name].dtype.kind == "i":
                    pad_value = int_pad_value
                    tensor_type = torch.int64
                else:
                    pad_value = float_pad_value
                    tensor_type = torch.float32
            tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)