mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add torch.utils._pytree.register_dataclass (#146059)
This is an API that registers a dataclass as a pytree node. It directly calls torch.export.register_dataclass, but we should eventually inline that implementation here. I want to use this API for something in compile and feel weird calling torch.export.register_dataclass. Test Plan: - tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/146059 Approved by: https://github.com/StrongerXi, https://github.com/angelayi, https://github.com/yanboliang
This commit is contained in:
parent
2fd1b6b361
commit
373606928b
|
|
@ -1182,6 +1182,19 @@ TreeSpec(tuple, None, [*,
|
|||
)
|
||||
self.assertEqual(all_zeros, [dict.fromkeys(range(10), 0)])
|
||||
|
||||
def test_dataclass(self):
|
||||
@dataclass
|
||||
class Point:
|
||||
x: torch.Tensor
|
||||
y: torch.Tensor
|
||||
|
||||
py_pytree.register_dataclass(Point)
|
||||
|
||||
point = Point(torch.tensor(0), torch.tensor(1))
|
||||
point = py_pytree.tree_map(lambda x: x + 1, point)
|
||||
self.assertEqual(point.x, torch.tensor(1))
|
||||
self.assertEqual(point.y, torch.tensor(2))
|
||||
|
||||
def test_tree_map_with_path_multiple_trees(self):
|
||||
@dataclass
|
||||
class ACustomPytree:
|
||||
|
|
|
|||
|
|
@ -255,6 +255,41 @@ def register_pytree_node(
|
|||
_cxx_pytree_pending_imports.append((args, kwargs))
|
||||
|
||||
|
||||
def register_dataclass(cls: type[Any]) -> None:
|
||||
"""Registers a ``dataclasses.dataclass`` type as a pytree node.
|
||||
|
||||
This is a simpler API than :func:`register_pytree_node` for registering
|
||||
a dataclass.
|
||||
|
||||
Args:
|
||||
cls: the dataclass type to register
|
||||
|
||||
Example:
|
||||
|
||||
>>> from torch import Tensor
|
||||
>>> from dataclasses import dataclass
|
||||
>>> import torch.utils._pytree as pytree
|
||||
>>>
|
||||
>>> @dataclass
|
||||
>>> class Point:
|
||||
>>> x: Tensor
|
||||
>>> y: Tensor
|
||||
>>>
|
||||
>>> pytree.register_dataclass(Point)
|
||||
>>>
|
||||
>>> point = Point(torch.tensor(0), torch.tensor(1))
|
||||
>>> point = pytree.tree_map(lambda x: x + 1, point)
|
||||
>>> assert torch.allclose(point.x, torch.tensor(1))
|
||||
>>> assert torch.allclose(point.y, torch.tensor(2))
|
||||
|
||||
"""
|
||||
import torch.export
|
||||
|
||||
# Eventually we should move the export code here. It is not specific to export,
|
||||
# aside from the serialization pieces.
|
||||
torch.export.register_dataclass(cls)
|
||||
|
||||
|
||||
def _register_namedtuple(
|
||||
cls: type[Any],
|
||||
*,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user