[pytree] freeze attributes of TreeSpec (#124011)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124011
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan 2024-05-21 17:32:01 +00:00 committed by PyTorch MergeBot
parent 6edf989e2f
commit 3b0f6cce5c

View File

@ -662,7 +662,7 @@ def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -
# context: some context that is useful in unflattening the pytree # context: some context that is useful in unflattening the pytree
# children_specs: specs for each child of the root Node # children_specs: specs for each child of the root Node
# num_leaves: the number of leaves # num_leaves: the number of leaves
@dataclasses.dataclass @dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False)
class TreeSpec: class TreeSpec:
type: Any type: Any
context: Context context: Context
@ -673,9 +673,12 @@ class TreeSpec:
num_children: int = dataclasses.field(init=False) num_children: int = dataclasses.field(init=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.num_nodes = 1 + sum(spec.num_nodes for spec in self.children_specs) num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1)
self.num_leaves = sum(spec.num_leaves for spec in self.children_specs) num_leaves = sum(spec.num_leaves for spec in self.children_specs)
self.num_children = len(self.children_specs) num_children = len(self.children_specs)
object.__setattr__(self, "num_nodes", num_nodes)
object.__setattr__(self, "num_leaves", num_leaves)
object.__setattr__(self, "num_children", num_children)
def __repr__(self, indent: int = 0) -> str: def __repr__(self, indent: int = 0) -> str:
repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, ["
@ -808,9 +811,9 @@ class LeafSpec(TreeSpec):
super().__init__(None, None, []) super().__init__(None, None, [])
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.num_nodes = 1 object.__setattr__(self, "num_nodes", 1)
self.num_leaves = 1 object.__setattr__(self, "num_leaves", 1)
self.num_children = 0 object.__setattr__(self, "num_children", 0)
def __repr__(self, indent: int = 0) -> str: def __repr__(self, indent: int = 0) -> str:
return "*" return "*"