| | |
| | | def to_device(data, device=None, dtype=None, non_blocking=False, copy=False): |
| | | """Change the device of object recursively""" |
| | | if isinstance(data, dict): |
| | | return { |
| | | k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items() |
| | | } |
| | | return {k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items()} |
| | | elif dataclasses.is_dataclass(data) and not isinstance(data, type): |
| | | return type(data)( |
| | | *[ |
| | | to_device(v, device, dtype, non_blocking, copy) |
| | | for v in dataclasses.astuple(data) |
| | | ] |
| | | *[to_device(v, device, dtype, non_blocking, copy) for v in dataclasses.astuple(data)] |
| | | ) |
| | | # maybe namedtuple. I don't know the correct way to judge namedtuple. |
| | | elif isinstance(data, tuple) and type(data) is not tuple: |
| | | return type(data)( |
| | | *[to_device(o, device, dtype, non_blocking, copy) for o in data] |
| | | ) |
| | | return type(data)(*[to_device(o, device, dtype, non_blocking, copy) for o in data]) |
| | | elif isinstance(data, (list, tuple)): |
| | | return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data) |
| | | elif isinstance(data, np.ndarray): |