Add torch.utils._pytree.register_dataclass (#146059)

This is an API that registers a dataclass as a pytree node.
It directly calls torch.export.register_dataclass, but we should
eventually inline that implementation here. I want to use this API for
something in compile and feel weird calling
torch.export.register_dataclass.

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146059
Approved by: https://github.com/StrongerXi, https://github.com/angelayi, https://github.com/yanboliang
This commit is contained in:
rzou 2025-01-31 09:04:25 -08:00 committed by PyTorch MergeBot
parent 2fd1b6b361
commit 373606928b
2 changed files with 48 additions and 0 deletions

View File

@ -1182,6 +1182,19 @@ TreeSpec(tuple, None, [*,
)
self.assertEqual(all_zeros, [dict.fromkeys(range(10), 0)])
def test_dataclass(self):
@dataclass
class Point:
x: torch.Tensor
y: torch.Tensor
py_pytree.register_dataclass(Point)
point = Point(torch.tensor(0), torch.tensor(1))
point = py_pytree.tree_map(lambda x: x + 1, point)
self.assertEqual(point.x, torch.tensor(1))
self.assertEqual(point.y, torch.tensor(2))
def test_tree_map_with_path_multiple_trees(self):
@dataclass
class ACustomPytree:

View File

@ -255,6 +255,41 @@ def register_pytree_node(
_cxx_pytree_pending_imports.append((args, kwargs))
def register_dataclass(cls: type[Any]) -> None:
"""Registers a ``dataclasses.dataclass`` type as a pytree node.
This is a simpler API than :func:`register_pytree_node` for registering
a dataclass.
Args:
cls: the dataclass type to register
Example:
>>> from torch import Tensor
>>> from dataclasses import dataclass
>>> import torch.utils._pytree as pytree
>>>
>>> @dataclass
>>> class Point:
>>> x: Tensor
>>> y: Tensor
>>>
>>> pytree.register_dataclass(Point)
>>>
>>> point = Point(torch.tensor(0), torch.tensor(1))
>>> point = pytree.tree_map(lambda x: x + 1, point)
>>> assert torch.allclose(point.x, torch.tensor(1))
>>> assert torch.allclose(point.y, torch.tensor(2))
"""
import torch.export
# Eventually we should move the export code here. It is not specific to export,
# aside from the serialization pieces.
torch.export.register_dataclass(cls)
def _register_namedtuple(
cls: type[Any],
*,